mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-03 22:04:18 +08:00
Compare commits
179 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
530c958afc | ||
|
|
57d861b578 | ||
|
|
99664298b9 | ||
|
|
a6fe5a7022 | ||
|
|
1918dad602 | ||
|
|
69399c291e | ||
|
|
9ec33ce320 | ||
|
|
c35d3aff7d | ||
|
|
2a5744d1c4 | ||
|
|
825511506b | ||
|
|
5a98a701cb | ||
|
|
dd1fa35c73 | ||
|
|
fb572fa849 | ||
|
|
c0a473ed19 | ||
|
|
030641adc6 | ||
|
|
445ef49dc8 | ||
|
|
32d4c60541 | ||
|
|
23f865be07 | ||
|
|
5d55325c12 | ||
|
|
900330509a | ||
|
|
cfb682ae3c | ||
|
|
abae90b16d | ||
|
|
470fc37f26 | ||
|
|
7a7caef1a6 | ||
|
|
a6aecb5d89 | ||
|
|
4a004f9aa1 | ||
|
|
1a6feae23b | ||
|
|
af5b2fa2c9 | ||
|
|
eeec45274b | ||
|
|
2b48c853fe | ||
|
|
c47f696691 | ||
|
|
9a8e4c8e15 | ||
|
|
24aab9a658 | ||
|
|
afdaaffac5 | ||
|
|
fe721116e2 | ||
|
|
8e0a834daa | ||
|
|
c9fca1561c | ||
|
|
5eb2dfd822 | ||
|
|
0b837c3f80 | ||
|
|
a6cfc12443 | ||
|
|
f6d64dd850 | ||
|
|
eed62caa78 | ||
|
|
204d41d6f3 | ||
|
|
858df0548e | ||
|
|
b3da021803 | ||
|
|
d234f826f4 | ||
|
|
231b69ecf8 | ||
|
|
0a08913677 | ||
|
|
49d32813ea | ||
|
|
c5d57e97b1 | ||
|
|
da8f7539a1 | ||
|
|
64a68f1176 | ||
|
|
1199d7cc3c | ||
|
|
8a827d2acb | ||
|
|
0e8a943d7f | ||
|
|
4f62658440 | ||
|
|
6e7c3d5f6a | ||
|
|
d5062db9b6 | ||
|
|
a6ad006a49 | ||
|
|
57d593fa17 | ||
|
|
f38b5ae870 | ||
|
|
418b3ca13c | ||
|
|
09bfa85e69 | ||
|
|
62b132208b | ||
|
|
fc28f4f74e | ||
|
|
f79a52f839 | ||
|
|
94d1041961 | ||
|
|
ada32d526a | ||
|
|
ef1e38aba1 | ||
|
|
60b2d59e25 | ||
|
|
e18aa73456 | ||
|
|
24747a5f09 | ||
|
|
621dac22dc | ||
|
|
23d7004b60 | ||
|
|
c3b3d34127 | ||
|
|
18a166afb0 | ||
|
|
a41447a96d | ||
|
|
df8d543539 | ||
|
|
5ecce8e0fe | ||
|
|
00f423a622 | ||
|
|
05ce04de69 | ||
|
|
cd5549e1aa | ||
|
|
f573c0255a | ||
|
|
060d7fffe6 | ||
|
|
38dbcd1643 | ||
|
|
241d97027c | ||
|
|
d18689fe9f | ||
|
|
b72298fef4 | ||
|
|
2d73503b00 | ||
|
|
fb106cd975 | ||
|
|
5f74aacfdf | ||
|
|
d9729a8a89 | ||
|
|
0665d5227d | ||
|
|
85a89669ff | ||
|
|
a2a77e607c | ||
|
|
258df26399 | ||
|
|
df9c980ca1 | ||
|
|
117f327e7b | ||
|
|
d599ba6be3 | ||
|
|
8484651fdd | ||
|
|
aab38648f8 | ||
|
|
9d4b45cf35 | ||
|
|
484e5cdc42 | ||
|
|
e37e11bf57 | ||
|
|
7661b71fcc | ||
|
|
b3a4306332 | ||
|
|
6aab140ec2 | ||
|
|
e260ad02bf | ||
|
|
4becc8d4d4 | ||
|
|
67f87989db | ||
|
|
17738b39a7 | ||
|
|
1e5312f96b | ||
|
|
548e69d87f | ||
|
|
90161a1f47 | ||
|
|
9ea3452b17 | ||
|
|
11e45fca37 | ||
|
|
c85fe979e5 | ||
|
|
a47edf1661 | ||
|
|
814a2e66c0 | ||
|
|
a7d548a849 | ||
|
|
b6a54190ed | ||
|
|
920228d3aa | ||
|
|
f1f568afca | ||
|
|
30bf666a57 | ||
|
|
c65d5244d6 | ||
|
|
4ad18e43ef | ||
|
|
f17cd66127 | ||
|
|
e1c068ed9e | ||
|
|
b86eac839d | ||
|
|
83252cbf33 | ||
|
|
12f6665519 | ||
|
|
1ff494416b | ||
|
|
8ec1d16e9d | ||
|
|
f13a4fba5f | ||
|
|
d4a3ed3a57 | ||
|
|
a6a1e7fb52 | ||
|
|
c01bc242aa | ||
|
|
ab06627d3f | ||
|
|
631d054d9e | ||
|
|
d835085e61 | ||
|
|
7c3ebe7e8b | ||
|
|
7e76d07e28 | ||
|
|
d21fb6c455 | ||
|
|
56f6f5e198 | ||
|
|
929592bbc4 | ||
|
|
2225a40bbe | ||
|
|
3480fa3b0f | ||
|
|
d7113f5fc4 | ||
|
|
2072f54ca1 | ||
|
|
7c9b721164 | ||
|
|
83ce50975a | ||
|
|
7da9110704 | ||
|
|
e9d19de7c6 | ||
|
|
e822831178 | ||
|
|
775930edce | ||
|
|
cb40848c04 | ||
|
|
7098c8755f | ||
|
|
705d602dee | ||
|
|
cd257a9406 | ||
|
|
cd54650431 | ||
|
|
a5602c602e | ||
|
|
dd70fd4c44 | ||
|
|
dbe50628b3 | ||
|
|
83ed0527d3 | ||
|
|
ab31f4bb98 | ||
|
|
734a8c4bc4 | ||
|
|
fea3af4692 | ||
|
|
9302cf295e | ||
|
|
b4f040e77a | ||
|
|
defabf4355 | ||
|
|
f3ed3168e4 | ||
|
|
01765b1731 | ||
|
|
f83f0fa768 | ||
|
|
a7085964e8 | ||
|
|
d3cd2856b7 | ||
|
|
353d22cc70 | ||
|
|
eb96474c19 | ||
|
|
0c48a2d74d | ||
|
|
1b23d574a5 |
51
.env.example
51
.env.example
@@ -1,18 +1,28 @@
|
||||
# MySQL数据库配置
|
||||
# 数据库配置
|
||||
DATABASE_TYPE=mysql
|
||||
#SQLITE_DATABASE=default_db
|
||||
MYSQL_HOST=gemini-balance-mysql
|
||||
#MYSQL_SOCKET=/run/mysqld/mysqld.sock
|
||||
MYSQL_PORT=3306
|
||||
MYSQL_USER=gemini
|
||||
MYSQL_PASSWORD=change_me
|
||||
MYSQL_DATABASE=default_db
|
||||
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
|
||||
ALLOWED_TOKENS=["sk-123456"]
|
||||
# AUTH_TOKEN=sk-123456
|
||||
AUTH_TOKEN=sk-123456
|
||||
# For Vertex AI Platform API Keys
|
||||
VERTEX_API_KEYS=["AQ.Abxxxxxxxxxxxxxxxxxxx"]
|
||||
# For Vertex AI Platform Express API Base URL
|
||||
VERTEX_EXPRESS_BASE_URL=https://aiplatform.googleapis.com/v1beta1/publishers/google
|
||||
TEST_MODEL=gemini-1.5-flash
|
||||
THINKING_MODELS=["gemini-2.5-flash-preview-04-17"]
|
||||
THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
|
||||
IMAGE_MODELS=["gemini-2.0-flash-exp"]
|
||||
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
||||
# 是否启用网址上下文,默认启用
|
||||
URL_CONTEXT_ENABLED=true
|
||||
URL_CONTEXT_MODELS=["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
SHOW_SEARCH_LINK=true
|
||||
SHOW_THINKING_PROCESS=true
|
||||
@@ -23,6 +33,11 @@ CHECK_INTERVAL_HOURS=1
|
||||
TIMEZONE=Asia/Shanghai
|
||||
# 请求超时时间(秒)
|
||||
TIME_OUT=300
|
||||
# 代理服务器配置 (支持 http 和 socks5)
|
||||
# 示例: PROXIES=["http://user:pass@host:port", "socks5://host:port"]
|
||||
PROXIES=[]
|
||||
# 对同一个API_KEY使用代理列表中固定的IP策略
|
||||
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY=true
|
||||
#########################image_generate 相关配置###########################
|
||||
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
|
||||
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
||||
@@ -31,6 +46,7 @@ SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||
PICGO_API_KEY=xxxx
|
||||
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER=
|
||||
##########################################################################
|
||||
#########################stream_optimizer 相关配置########################
|
||||
STREAM_OPTIMIZER_ENABLED=false
|
||||
@@ -43,4 +59,35 @@ STREAM_CHUNK_SIZE=5
|
||||
######################### 日志配置 #######################################
|
||||
# 日志级别 (debug, info, warning, error, critical),默认为 info
|
||||
LOG_LEVEL=info
|
||||
# 是否开启自动删除错误日志
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED=true
|
||||
# 自动删除多少天前的错误日志 (1, 7, 30)
|
||||
AUTO_DELETE_ERROR_LOGS_DAYS=7
|
||||
# 是否开启自动删除请求日志
|
||||
AUTO_DELETE_REQUEST_LOGS_ENABLED=false
|
||||
# 自动删除多少天前的请求日志 (1, 7, 30)
|
||||
AUTO_DELETE_REQUEST_LOGS_DAYS=30
|
||||
##########################################################################
|
||||
|
||||
# 假流式配置 (Fake Streaming Configuration)
|
||||
# 是否启用假流式输出
|
||||
FAKE_STREAM_ENABLED=True
|
||||
# 假流式发送空数据的间隔时间(秒)
|
||||
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS=5
|
||||
|
||||
# 安全设置 (JSON 字符串格式)
|
||||
# 注意:这里的示例值可能需要根据实际模型支持情况调整
|
||||
SAFETY_SETTINGS=[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]
|
||||
URL_NORMALIZATION_ENABLED=false
|
||||
# tts配置
|
||||
TTS_MODEL=gemini-2.5-flash-preview-tts
|
||||
TTS_VOICE_NAME=Zephyr
|
||||
TTS_SPEED=normal
|
||||
#########################Files API 相关配置########################
|
||||
# 是否启用文件过期自动清理
|
||||
FILES_CLEANUP_ENABLED=true
|
||||
# 文件过期清理间隔(小时)
|
||||
FILES_CLEANUP_INTERVAL_HOURS=1
|
||||
# 是否启用用户文件隔离(每个用户只能看到自己上传的文件)
|
||||
FILES_USER_ISOLATION_ENABLED=true
|
||||
##########################################################################
|
||||
22
.github/workflows/release.yml
vendored
22
.github/workflows/release.yml
vendored
@@ -3,7 +3,7 @@ name: Publish Release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*' # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
||||
- "v*" # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
||||
|
||||
jobs:
|
||||
update-release-draft:
|
||||
@@ -15,8 +15,17 @@ jobs:
|
||||
# Step 1: 检出代码库
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
# Step 2: 自动生成 Release
|
||||
# Step 2: 自动生成 Release Notes
|
||||
- name: Generate release notes
|
||||
id: changelog
|
||||
uses: mikepenz/release-changelog-builder-action@v4
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# Step 3: 自动生成 Release
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
@@ -25,15 +34,16 @@ jobs:
|
||||
with:
|
||||
tag_name: ${{ github.ref_name }}
|
||||
release_name: ${{ github.ref_name }}
|
||||
body: ${{ steps.changelog.outputs.changelog }}
|
||||
draft: false
|
||||
prerelease: false
|
||||
|
||||
# Step 3: 可选,构建zip文件
|
||||
|
||||
# Step 4: 可选,构建zip文件
|
||||
- name: Create ZIP file
|
||||
run: |
|
||||
zip -r gemini-balance.zip . -x "*.git*" "*.github*" "*.env*" "logs/*" "tests/*"
|
||||
|
||||
# Step 4: 可选,上传构建文件
|
||||
# Step 5: 可选,上传构建文件
|
||||
- name: Upload Release Asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
@@ -41,5 +51,5 @@ jobs:
|
||||
with:
|
||||
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||
asset_path: ./gemini-balance.zip # 替换为你的构建文件路径
|
||||
asset_name: gemini-balance.zip # 替换为你的文件名
|
||||
asset_name: gemini-balance.zip # 替换为你的文件名
|
||||
asset_content_type: application/zip
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -257,4 +257,5 @@ $RECYCLE.BIN/
|
||||
|
||||
# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
|
||||
|
||||
tests/
|
||||
tests/
|
||||
default_db
|
||||
@@ -4,15 +4,10 @@ WORKDIR /app
|
||||
|
||||
# 复制所需文件到容器中
|
||||
COPY ./requirements.txt /app
|
||||
COPY ./VERSION /app
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY ./app /app/app
|
||||
ENV API_KEYS='["your_api_key_1"]'
|
||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
|
||||
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
364
README.md
364
README.md
@@ -1,229 +1,295 @@
|
||||
# Gemini Balance - Gemini API 代理和负载均衡器
|
||||
[Read this document in Chinese](README_ZH.md)
|
||||
|
||||
> ⚠️ 本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
# Gemini Balance - Gemini API Proxy and Load Balancer
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/13692" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> ⚠️ This project is licensed under the CC BY-NC 4.0 (Attribution-NonCommercial) license. Any form of commercial resale service is prohibited. See the LICENSE file for details.
|
||||
|
||||
> I have never sold this service on any platform. If you encounter someone selling this service, they are definitely a reseller. Please be careful not to be deceived.
|
||||
|
||||
[](https://www.python.org/)
|
||||
[](https://fastapi.tiangolo.com/)
|
||||
[](https://www.uvicorn.org/)
|
||||
[](https://t.me/+soaHax5lyI0wZDVl)
|
||||
|
||||
## 项目简介
|
||||
> Telegram Group: <https://t.me/+soaHax5lyI0wZDVl>
|
||||
|
||||
Gemini Balance 是一个基于 Python FastAPI 构建的应用程序,旨在提供 Google Gemini API 的代理和负载均衡功能。它允许您管理多个 Gemini API Key,并通过简单的配置实现 Key 的轮询、认证、模型过滤和状态监控。此外,项目还集成了图像生成和多种图床上传功能,并支持 OpenAI API 格式的代理。
|
||||
## Project Introduction
|
||||
|
||||
**项目结构:**
|
||||
Gemini Balance is an application built with Python FastAPI, designed to provide proxy and load balancing functions for the Google Gemini API. It allows you to manage multiple Gemini API Keys and implement key rotation, authentication, model filtering, and status monitoring through simple configuration. Additionally, the project integrates image generation and multiple image hosting upload functions, and supports proxying in the OpenAI API format.
|
||||
|
||||
**Project Structure:**
|
||||
|
||||
```plaintext
|
||||
app/
|
||||
├── config/ # 配置管理
|
||||
├── core/ # 核心应用逻辑 (FastAPI 实例创建, 中间件等)
|
||||
├── database/ # 数据库模型和连接
|
||||
├── domain/ # 业务领域对象 (可选)
|
||||
├── exception/ # 自定义异常
|
||||
├── handler/ # 请求处理器 (可选, 或在 router 中处理)
|
||||
├── log/ # 日志配置
|
||||
├── main.py # 应用入口
|
||||
├── middleware/ # FastAPI 中间件
|
||||
├── router/ # API 路由 (Gemini, OpenAI, 状态页等)
|
||||
├── scheduler/ # 定时任务 (如 Key 状态检查)
|
||||
├── service/ # 业务逻辑服务 (聊天, Key 管理, 统计等)
|
||||
├── static/ # 静态文件 (CSS, JS)
|
||||
├── templates/ # HTML 模板 (如 Key 状态页)
|
||||
├── utils/ # 工具函数
|
||||
├── config/ # Configuration management
|
||||
├── core/ # Core application logic (FastAPI instance creation, middleware, etc.)
|
||||
├── database/ # Database models and connections
|
||||
├── domain/ # Business domain objects (optional)
|
||||
├── exception/ # Custom exceptions
|
||||
├── handler/ # Request handlers (optional, or handled in router)
|
||||
├── log/ # Logging configuration
|
||||
├── main.py # Application entry point
|
||||
├── middleware/ # FastAPI middleware
|
||||
├── router/ # API routes (Gemini, OpenAI, status page, etc.)
|
||||
├── scheduler/ # Scheduled tasks (e.g., Key status check)
|
||||
├── service/ # Business logic services (chat, Key management, statistics, etc.)
|
||||
├── static/ # Static files (CSS, JS)
|
||||
├── templates/ # HTML templates (e.g., Key status page)
|
||||
├── utils/ # Utility functions
|
||||
```
|
||||
|
||||
## ✨ 功能亮点
|
||||
## ✨ Feature Highlights
|
||||
|
||||
* **多 Key 负载均衡**: 支持配置多个 Gemini API Key (`API_KEYS`),自动按顺序轮询使用,提高可用性和并发能力。
|
||||
* **可视化配置即时生效**: 通过管理后台修改配置后,无需重启服务即可生效,切记要点击保存才会生效。
|
||||

|
||||
* **双协议API 兼容**: 同时支持 Gemini 和 OpenAI 格式的 CHAT API 请求转发。
|
||||
* **Multi-Key Load Balancing**: Supports configuring multiple Gemini API Keys (`API_KEYS`) for automatic sequential polling, improving availability and concurrency.
|
||||
* **Visual Configuration Takes Effect Immediately**: Configurations modified through the admin backend take effect without restarting the service. Remember to click save for changes to apply.
|
||||

|
||||
* **Dual Protocol API Compatibility**: Supports forwarding CHAT API requests in both Gemini and OpenAI formats.
|
||||
|
||||
```palintext
|
||||
```plaintext
|
||||
openai baseurl `http://localhost:8000(/hf)/v1`
|
||||
gemini baseurl `http://localhost:8000(/gemini)/v1beta`
|
||||
```
|
||||
|
||||
* **支持图文对话和修改图片**: `IMAGE_MODELS`配置哪个模型可以图文对话和修图的功能,实际调用的时候,用 `配置模型-image`这个模型名对话使用该功能。
|
||||

|
||||

|
||||
* **支持联网搜索**: 支持联网搜索,`SEARCH_MODELS`配置哪些模型可以联网搜索,实际调用的时候,用 `配置模型-search`这个模型名对话使用该功能
|
||||

|
||||
* **Key 状态监控**: 提供 `/keys_status` 页面(需要认证),实时查看各 Key 的状态和使用情况。
|
||||

|
||||
* **详细的日志记录**: 提供详细的错误日志,方便排查。
|
||||

|
||||

|
||||

|
||||
* **支持自定义gemini代理**: 支持自定义gemini代理,比如自行在deno或者cloudflare上搭建gemini代理
|
||||
* **openai画图接口兼容**: 将`imagen-3.0-generate-002`模型接口改造成openai画图接口,支持客户端调用。
|
||||
* **灵活的添加密钥方式**: 灵活的添加密钥方式,采用正则匹配`gemini_key`,密钥去重
|
||||

|
||||
* **兼容openai格式embeddings接口**:完美适配openai格式的`embeddings`接口,可用于本地文档向量化。
|
||||
* **流式响应优化**: 可选的流式输出优化器 (`STREAM_OPTIMIZER_ENABLED`),改善长文本流式响应的体验。
|
||||
* **失败重试与 Key 管理**: 自动处理 API 请求失败,进行重试 (`MAX_RETRIES`),并在 Key 失效次数过多时自动禁用 (`MAX_FAILURES`),定时检查恢复 (`CHECK_INTERVAL_HOURS`)。
|
||||
* **Docker 支持**: 支持AMD,ARM架构的docker部署,也可自行构建docker镜像。
|
||||
>镜像地址: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **模型列表自动维护**: 支持openai和gemini模型列表获取,与newapi自动获取模型列表完美兼容,无需手动填写。
|
||||
* **支持移除不使用的模型**: 默认提供的模型太多,很多用不上,可以通过`FILTERED_MODELS`过滤掉。
|
||||
* **Supports Image-Text Chat and Image Modification**: `IMAGE_MODELS` configures which models can perform image-text chat and image editing. When actually calling, use the `configured_model-image` model name to use this feature.
|
||||

|
||||

|
||||
* **Supports Web Search**: Supports web search. `SEARCH_MODELS` configures which models can perform web searches. When actually calling, use the `configured_model-search` model name to use this feature.
|
||||

|
||||
* **Key Status Monitoring**: Provides a `/keys_status` page (requires authentication) to view the status and usage of each Key in real-time.
|
||||

|
||||
* **Detailed Logging**: Provides detailed error logs for easy troubleshooting.
|
||||

|
||||

|
||||

|
||||
* **Support for Custom Gemini Proxy**: Supports custom Gemini proxies, such as those built on Deno or Cloudflare.
|
||||
* **OpenAI Image Generation API Compatibility**: Adapts the `imagen-3.0-generate-002` model interface to be compatible with the OpenAI image generation API, supporting client calls.
|
||||
* **Flexible Key Addition**: Flexible way to add keys using regex matching for `gemini_key`, with key deduplication.
|
||||

|
||||
* **OpenAI Format Embeddings API Compatibility**: Perfectly adapts to the OpenAI format `embeddings` interface, usable for local document vectorization.
|
||||
* **Streamlined Response Optimization**: Optional stream output optimizer (`STREAM_OPTIMIZER_ENABLED`) to improve the experience of long-text stream responses.
|
||||
* **Failure Retry and Key Management**: Automatically handles API request failures, retries (`MAX_RETRIES`), automatically disables Keys after too many failures (`MAX_FAILURES`), and periodically checks for recovery (`CHECK_INTERVAL_HOURS`).
|
||||
* **Docker Support**: Supports AMD and ARM architecture Docker deployments. You can also build your own Docker image.
|
||||
> Image address: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **Automatic Model List Maintenance**: Supports fetching OpenAI and Gemini model lists, perfectly compatible with NewAPI's automatic model list fetching, no manual entry required.
|
||||
* **Support for Removing Unused Models**: Too many default models are provided, many of which are not used. You can filter them out using `FILTERED_MODELS`.
|
||||
* **Proxy Support**: Supports configuring HTTP/SOCKS5 proxy servers (`PROXIES`) for accessing the Gemini API, convenient for use in special network environments. Supports batch adding proxies.
|
||||
|
||||
## 🚀 快速开始
|
||||
## 🚀 Quick Start
|
||||
|
||||
### 自行构建 Docker (推荐)
|
||||
### Build Docker Yourself (Recommended)
|
||||
|
||||
#### a) dockerfile构建
|
||||
#### a) Build with Dockerfile
|
||||
|
||||
1. **构建镜像**:
|
||||
1. **Build Image**:
|
||||
|
||||
```bash
|
||||
docker build -t gemini-balance .
|
||||
```
|
||||
|
||||
2. **运行容器**:
|
||||
2. **Run Container**:
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env gemini-balance
|
||||
```
|
||||
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量。
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host.
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables.
|
||||
|
||||
#### b) 用现有的docker镜像部署
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist
|
||||
>
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data gemini-balance
|
||||
> ```
|
||||
>
|
||||
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
|
||||
|
||||
1. **拉取镜像**:
|
||||
#### b) Deploy with an Existing Docker Image
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
1. **Pull Image**:
|
||||
|
||||
2. **运行容器**:
|
||||
```bash
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
2. **Run Container**:
|
||||
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口 (根据需要调整)。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量 (确保 `.env` 文件存在于执行命令的目录)。
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
### 本地运行 (适用于开发和测试)
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host (adjust as needed).
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables (ensure the `.env` file exists in the directory where the command is executed).
|
||||
|
||||
如果您想在本地直接运行源代码进行开发或测试,请按照以下步骤操作:
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist
|
||||
>
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data ghcr.io/snailyp/gemini-balance:latest
|
||||
> ```
|
||||
>
|
||||
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
|
||||
|
||||
1. **确保已完成准备工作**:
|
||||
* 克隆仓库到本地。
|
||||
* 安装 Python 3.9 或更高版本。
|
||||
* 在项目根目录下创建并配置好 `.env` 文件 (参考前面的“配置环境变量”部分)。
|
||||
* 安装项目依赖:
|
||||
### Run Locally (Suitable for Development and Testing)
|
||||
|
||||
If you want to run the source code directly locally for development or testing, follow these steps:
|
||||
|
||||
1. **Ensure Prerequisites are Met**:
|
||||
* Clone the repository locally.
|
||||
* Install Python 3.9 or higher.
|
||||
* Create and configure the `.env` file in the project root directory (refer to the "Configure Environment Variables" section above).
|
||||
* Install project dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **启动应用**:
|
||||
在项目根目录下运行以下命令:
|
||||
2. **Start Application**:
|
||||
Run the following command in the project root directory:
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
* `app.main:app`: 指定 FastAPI 应用实例的位置 (`app` 模块中的 `main.py` 文件里的 `app` 对象)。
|
||||
* `--host 0.0.0.0`: 使应用可以从本地网络中的任何 IP 地址访问。
|
||||
* `--port 8000`: 指定应用监听的端口号 (您可以根据需要修改)。
|
||||
* `--reload`: 启用自动重载功能。当您修改代码时,服务会自动重启,非常适合开发环境 (生产环境请移除此选项)。
|
||||
* `app.main:app`: Specifies the location of the FastAPI application instance (the `app` object in the `main.py` file within the `app` module).
|
||||
* `--host 0.0.0.0`: Makes the application accessible from any IP address on the local network.
|
||||
* `--port 8000`: Specifies the port number the application listens on (you can change this as needed).
|
||||
* `--reload`: Enables automatic reloading. When you modify the code, the service will automatically restart, which is very suitable for development environments (remove this option in production environments).
|
||||
|
||||
3. **访问应用**:
|
||||
应用启动后,您可以通过浏览器或 API 工具访问 `http://localhost:8000` (或您指定的主机和端口)。
|
||||
3. **Access Application**:
|
||||
After the application starts, you can access `http://localhost:8000` (or the host and port you specified) through a browser or API tool.
|
||||
|
||||
### 完整配置项列表
|
||||
### Complete Configuration List
|
||||
|
||||
| 配置项 | 说明 | 默认值 |
|
||||
| :--------------------------- | :------------------------------------------------------- | :---------------------------------------------------- |
|
||||
| **数据库配置** | | |
|
||||
| `MYSQL_HOST` | 必填,MySQL 数据库主机地址 | `localhost` |
|
||||
| `MYSQL_PORT` | 必填,MySQL 数据库端口 | `3306` |
|
||||
| `MYSQL_USER` | 必填,MySQL 数据库用户名 | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | 必填,MySQL 数据库密码 | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | 必填,MySQL 数据库名称 | `defaultdb` |
|
||||
| **API 相关配置** | | |
|
||||
| `API_KEYS` | 必填,Gemini API 密钥列表,用于负载均衡 | `["your-gemini-api-key-1", "your-gemini-api-key-2"]` |
|
||||
| `ALLOWED_TOKENS` | 必填,允许访问的 Token 列表 | `["your-access-token-1", "your-access-token-2"]` |
|
||||
| `AUTH_TOKEN` | 可选,超级管理员token,具有所有权限,不填默认使用 ALLOWED_TOKENS 的第一个 | `""` |
|
||||
| `TEST_MODEL` | 可选,用于测试密钥是否可用的模型名 | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | 可选,支持绘图功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | 可选,支持搜索功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `FILTERED_MODELS` | 可选,被禁用的模型列表 | `["gemini-1.0-pro-vision-latest", ...]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | 可选,是否启用代码执行工具 | `false` |
|
||||
| `SHOW_SEARCH_LINK` | 可选,是否在响应中显示搜索结果链接 | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | 可选,是否显示模型思考过程 | `true` |
|
||||
| `THINKING_MODELS` | 可选,支持思考功能的模型列表 | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | 可选,思考功能预算映射 (模型名:预算值) | `{}` |
|
||||
| `BASE_URL` | 可选,Gemini API 基础 URL,默认无需修改 | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | 可选,允许单个key失败的次数 | `3` |
|
||||
| `MAX_RETRIES` | 可选,API 请求失败时的最大重试次数 | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | 可选,检查禁用 Key 是否恢复的时间间隔 (小时) | `1` |
|
||||
| `TIMEZONE` | 可选,应用程序使用的时区 | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | 可选,请求超时时间 (秒) | `300` |
|
||||
| `LOG_LEVEL` | 可选,日志级别,例如 DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 可选,付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | 可选,图片生成模型 | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | 可选,图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `SMMS_SECRET_TOKEN` | 可选,SM.MS图床的API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | 可选,[PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | 可选,[CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| 可选,CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
||||
| **流式优化器相关** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | 可选,是否启用流式输出优化 | `false` |
|
||||
| `STREAM_MIN_DELAY` | 可选,流式输出最小延迟 | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | 可选,流式输出最大延迟 | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD`| 可选,短文本阈值 | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | 可选,长文本阈值 | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | 可选,流式输出块大小 | `5` |
|
||||
| Configuration Item | Description | Default Value |
|
||||
| :----------------------------- | :-------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Database Configuration** | | |
|
||||
| `DATABASE_TYPE` | Optional, database type, supports `mysql` or `sqlite` | `mysql` |
|
||||
| `SQLITE_DATABASE` | Optional, required when using `sqlite`, SQLite database file path | `default_db` |
|
||||
| `MYSQL_HOST` | Required when using `mysql`, MySQL database host address | `localhost` |
|
||||
| `MYSQL_SOCKET` | Optional, MySQL database socket address | `/var/run/mysqld/mysqld.sock` |
|
||||
| `MYSQL_PORT` | Required when using `mysql`, MySQL database port | `3306` |
|
||||
| `MYSQL_USER` | Required when using `mysql`, MySQL database username | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | Required when using `mysql`, MySQL database password | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | Required when using `mysql`, MySQL database name | `defaultdb` |
|
||||
| **API Related Configuration** | | |
|
||||
| `API_KEYS` | Required, list of Gemini API keys for load balancing | `["your-gemini-api-key-1", "your-gemini-api-key-2"]` |
|
||||
| `ALLOWED_TOKENS` | Required, list of tokens allowed to access | `["your-access-token-1", "your-access-token-2"]` |
|
||||
| `AUTH_TOKEN` | Optional, super admin token with all permissions, defaults to the first of `ALLOWED_TOKENS` if not set | `sk-123456` |
|
||||
| `TEST_MODEL` | Optional, model name used to test if a key is usable | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | Optional, list of models that support drawing functions | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | Optional, list of models that support search functions | `["gemini-2.0-flash-exp"]` |
|
||||
| `FILTERED_MODELS` | Optional, list of disabled models | `["gemini-1.0-pro-vision-latest", ...]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | Optional, whether to enable the code execution tool | `false` |
|
||||
| `SHOW_SEARCH_LINK` | Optional, whether to display search result links in the response | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | Optional, whether to display the model's thinking process | `true` |
|
||||
| `THINKING_MODELS` | Optional, list of models that support thinking functions | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | Optional, thinking function budget mapping (model_name:budget_value) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | Optional, whether to enable intelligent URL routing mapping | `false` |
|
||||
| `URL_CONTEXT_ENABLED` | Optional, whether to enable URL context understanding | `false` |
|
||||
| `URL_CONTEXT_MODELS` | Optional, list of models that support URL context understanding | `[]` |
|
||||
| `BASE_URL` | Optional, Gemini API base URL, no modification needed by default | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | Optional, number of times a single key is allowed to fail | `3` |
|
||||
| `MAX_RETRIES` | Optional, maximum number of retries for failed API requests | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | Optional, time interval (hours) to check if a disabled Key has recovered | `1` |
|
||||
| `TIMEZONE` | Optional, timezone used by the application | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | Optional, request timeout (seconds) | `300` |
|
||||
| `PROXIES` | Optional, list of proxy servers (e.g., `http://user:pass@host:port`, `socks5://host:port`) | `[]` |
|
||||
| `LOG_LEVEL` | Optional, log level, e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | Optional, whether to enable automatic deletion of error logs | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | Optional, automatically delete error logs older than this many days (e.g., 1, 7, 30) | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Optional, whether to enable automatic deletion of request logs | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | Optional, automatically delete request logs older than this many days (e.g., 1, 7, 30) | `30` |
|
||||
| `SAFETY_SETTINGS` | Optional, safety settings (JSON string format), used to configure content safety thresholds. Example values may need adjustment based on actual model support. | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
|
||||
| **TTS Related** | | |
|
||||
| `TTS_MODEL` | Optional, TTS model name | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | Optional, TTS voice name | `Zephyr` |
|
||||
| `TTS_SPEED` | Optional, TTS speed | `normal` |
|
||||
| **Image Generation Related** | | |
|
||||
| `PAID_KEY` | Optional, paid API Key for advanced features like image generation | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | Optional, image generation model | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | Optional, image upload provider: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `SMMS_SECRET_TOKEN` | Optional, API Token for SM.MS image hosting | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | Optional, API Key for [PicoGo](https://www.picgo.net/) image hosting | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | Optional, [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) image hosting upload address | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE` | Optional, authentication key for CloudFlare image hosting | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER` | Optional, upload folder path for CloudFlare image hosting | `""` |
|
||||
| **Stream Optimizer Related** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | Optional, whether to enable stream output optimization | `false` |
|
||||
| `STREAM_MIN_DELAY` | Optional, minimum delay for stream output | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | Optional, maximum delay for stream output | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD` | Optional, short text threshold | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | Optional, long text threshold | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | Optional, stream output chunk size | `5` |
|
||||
| **Fake Stream Related** | | |
|
||||
| `FAKE_STREAM_ENABLED` | Optional, whether to enable fake streaming for models or scenarios that don't support streaming | `false` |
|
||||
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | Optional, interval in seconds for sending heartbeat empty data during fake streaming | `5` |
|
||||
|
||||
## ⚙️ API 端点
|
||||
## ⚙️ API Endpoints
|
||||
|
||||
以下是服务提供的主要 API 端点:
|
||||
The following are the main API endpoints provided by the service:
|
||||
|
||||
### Gemini API 相关 (`(/gemini)/v1beta`)
|
||||
### Gemini API Related (`(/gemini)/v1beta`)
|
||||
|
||||
* `GET /models`: 列出可用的 Gemini 模型。
|
||||
* `POST /models/{model_name}:generateContent`: 使用指定的 Gemini 模型生成内容。
|
||||
* `POST /models/{model_name}:streamGenerateContent`: 使用指定的 Gemini 模型流式生成内容。
|
||||
* `GET /models`: List available Gemini models.
|
||||
* `POST /models/{model_name}:generateContent`: Generate content using the specified Gemini model.
|
||||
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation using the specified Gemini model.
|
||||
|
||||
### OpenAI API 相关 (`(/hf)/v1`)
|
||||
### OpenAI API Related
|
||||
|
||||
* `GET /v1/models`: 列出可用的 OpenAI 模型。
|
||||
* `POST /v1/chat/completions`: 通过 OpenAI API 进行聊天补全。
|
||||
* `POST /v1/images/generations`: 通过 OpenAI API 生成图像。
|
||||
* `POST /v1/embeddings`: 通过 OpenAI API 创建文本嵌入。
|
||||
* `GET (/hf)/v1/models`: List available models (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/chat/completions`: Perform chat completion (uses Gemini format underneath, supports streaming).
|
||||
* `POST (/hf)/v1/embeddings`: Create text embeddings (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/images/generations`: Generate images (uses Gemini format underneath).
|
||||
* `GET /openai/v1/models`: List available models (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/chat/completions`: Perform chat completion (uses OpenAI format underneath, supports streaming, can prevent truncation, and is faster).
|
||||
* `POST /openai/v1/embeddings`: Create text embeddings (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/images/generations`: Generate images (uses OpenAI format underneath).
|
||||
|
||||
## 🤝 贡献
|
||||
## 🤝 Contributing
|
||||
|
||||
欢迎提交 Pull Request 或 Issue。
|
||||
Pull Requests or Issues are welcome.
|
||||
|
||||
## 🎉 特别鸣谢
|
||||
## 🎉 Special Thanks
|
||||
|
||||
特别鸣谢以下项目和平台为本项目提供图床服务:
|
||||
Special thanks to the following projects and platforms for providing image hosting services for this project:
|
||||
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) 开源项目
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) open source project
|
||||
|
||||
## 🙏 感谢贡献者
|
||||
## 🙏 Thanks to Contributors
|
||||
|
||||
感谢所有为本项目做出贡献的开发者!
|
||||
Thanks to all developers who contributed to this project!
|
||||
|
||||
[](https://github.com/snailyp/gemini-balance/graphs/contributors)
|
||||
|
||||
## Thanks to Our Supporters
|
||||
|
||||
A special shout-out to DigitalOcean for providing the rock-solid and dependable cloud infrastructure that keeps this project humming!
|
||||
[](https://m.do.co/c/b249dd7f3b4c)
|
||||
|
||||
CDN acceleration and security protection for this project are sponsored by Tencent EdgeOne.
|
||||
[](https://edgeone.ai/?from=github)
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
[](https://star-history.com/#snailyp/gemini-balance&Date)
|
||||
|
||||
## 💖 友情项目
|
||||
## 💖 Friendly Projects
|
||||
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine一线:AI驱动的热点事件时间轴生成工具
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine: AI-driven hot event timeline generation tool
|
||||
|
||||
## 许可证
|
||||
## 🎁 Project Support
|
||||
|
||||
本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
If you find this project helpful, consider supporting me via [Afdian](https://afdian.com/a/snaily).
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the CC BY-NC 4.0 (Attribution-NonCommercial) license. Any form of commercial resale service is prohibited. See the LICENSE file for details.
|
||||
|
||||
280
README_ZH.md
Normal file
280
README_ZH.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# Gemini Balance - Gemini API 代理和负载均衡器
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/13692" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> ⚠️ 本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
|
||||
> 本人从未在各个平台售卖服务,如有遇到售卖此服务者,那一定是倒卖狗,大家切记不要上当受骗。
|
||||
|
||||
[](https://www.python.org/)
|
||||
[](https://fastapi.tiangolo.com/)
|
||||
[](https://www.uvicorn.org/)
|
||||
[](https://t.me/+soaHax5lyI0wZDVl)
|
||||
> 交流群:https://t.me/+soaHax5lyI0wZDVl
|
||||
|
||||
## 项目简介
|
||||
|
||||
Gemini Balance 是一个基于 Python FastAPI 构建的应用程序,旨在提供 Google Gemini API 的代理和负载均衡功能。它允许您管理多个 Gemini API Key,并通过简单的配置实现 Key 的轮询、认证、模型过滤和状态监控。此外,项目还集成了图像生成和多种图床上传功能,并支持 OpenAI API 格式的代理。
|
||||
|
||||
**项目结构:**
|
||||
|
||||
```plaintext
|
||||
app/
|
||||
├── config/ # 配置管理
|
||||
├── core/ # 核心应用逻辑 (FastAPI 实例创建, 中间件等)
|
||||
├── database/ # 数据库模型和连接
|
||||
├── domain/ # 业务领域对象 (可选)
|
||||
├── exception/ # 自定义异常
|
||||
├── handler/ # 请求处理器 (可选, 或在 router 中处理)
|
||||
├── log/ # 日志配置
|
||||
├── main.py # 应用入口
|
||||
├── middleware/ # FastAPI 中间件
|
||||
├── router/ # API 路由 (Gemini, OpenAI, 状态页等)
|
||||
├── scheduler/ # 定时任务 (如 Key 状态检查)
|
||||
├── service/ # 业务逻辑服务 (聊天, Key 管理, 统计等)
|
||||
├── static/ # 静态文件 (CSS, JS)
|
||||
├── templates/ # HTML 模板 (如 Key 状态页)
|
||||
├── utils/ # 工具函数
|
||||
```
|
||||
|
||||
## ✨ 功能亮点
|
||||
|
||||
* **多 Key 负载均衡**: 支持配置多个 Gemini API Key (`API_KEYS`),自动按顺序轮询使用,提高可用性和并发能力。
|
||||
* **可视化配置即时生效**: 通过管理后台修改配置后,无需重启服务即可生效,切记要点击保存才会生效。
|
||||

|
||||
* **双协议API 兼容**: 同时支持 Gemini 和 OpenAI 格式的 CHAT API 请求转发。
|
||||
|
||||
```palintext
|
||||
openai baseurl `http://localhost:8000(/hf)/v1`
|
||||
gemini baseurl `http://localhost:8000(/gemini)/v1beta`
|
||||
```
|
||||
|
||||
* **支持图文对话和修改图片**: `IMAGE_MODELS`配置哪个模型可以图文对话和修图的功能,实际调用的时候,用 `配置模型-image`这个模型名对话使用该功能。
|
||||

|
||||

|
||||
* **支持联网搜索**: 支持联网搜索,`SEARCH_MODELS`配置哪些模型可以联网搜索,实际调用的时候,用 `配置模型-search`这个模型名对话使用该功能
|
||||

|
||||
* **Key 状态监控**: 提供 `/keys_status` 页面(需要认证),实时查看各 Key 的状态和使用情况。
|
||||

|
||||
* **详细的日志记录**: 提供详细的错误日志,方便排查。
|
||||

|
||||

|
||||

|
||||
* **支持自定义gemini代理**: 支持自定义gemini代理,比如自行在deno或者cloudflare上搭建gemini代理
|
||||
* **openai画图接口兼容**: 将`imagen-3.0-generate-002`模型接口改造成openai画图接口,支持客户端调用。
|
||||
* **灵活的添加密钥方式**: 灵活的添加密钥方式,采用正则匹配`gemini_key`,密钥去重
|
||||

|
||||
* **兼容openai格式embeddings接口**:完美适配openai格式的`embeddings`接口,可用于本地文档向量化。
|
||||
* **流式响应优化**: 可选的流式输出优化器 (`STREAM_OPTIMIZER_ENABLED`),改善长文本流式响应的体验。
|
||||
* **失败重试与 Key 管理**: 自动处理 API 请求失败,进行重试 (`MAX_RETRIES`),并在 Key 失效次数过多时自动禁用 (`MAX_FAILURES`),定时检查恢复 (`CHECK_INTERVAL_HOURS`)。
|
||||
* **Docker 支持**: 支持AMD,ARM架构的docker部署,也可自行构建docker镜像。
|
||||
>镜像地址: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **模型列表自动维护**: 支持openai和gemini模型列表获取,与newapi自动获取模型列表完美兼容,无需手动填写。
|
||||
* **支持移除不使用的模型**: 默认提供的模型太多,很多用不上,可以通过`FILTERED_MODELS`过滤掉。
|
||||
* **代理支持**: 支持配置 HTTP/SOCKS5 代理服务器 (`PROXIES`),用于访问 Gemini API,方便在特殊网络环境下使用。支持批量添加代理。
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 自行构建 Docker (推荐)
|
||||
|
||||
#### a) dockerfile构建
|
||||
|
||||
1. **构建镜像**:
|
||||
|
||||
```bash
|
||||
docker build -t gemini-balance .
|
||||
```
|
||||
|
||||
2. **运行容器**:
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env gemini-balance
|
||||
```
|
||||
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量。
|
||||
|
||||
> 注意:如果使用 SQLite 数据库,需要挂载数据卷以持久化数据:
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data gemini-balance
|
||||
> ```
|
||||
> 其中 `/path/to/data` 是主机上的数据存储路径,`/app/data` 是容器内的数据目录。
|
||||
|
||||
#### b) 用现有的docker镜像部署
|
||||
|
||||
1. **拉取镜像**:
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
2. **运行容器**:
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口 (根据需要调整)。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量 (确保 `.env` 文件存在于执行命令的目录)。
|
||||
|
||||
> 注意:如果使用 SQLite 数据库,需要挂载数据卷以持久化数据:
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data ghcr.io/snailyp/gemini-balance:latest
|
||||
> ```
|
||||
> 其中 `/path/to/data` 是主机上的数据存储路径,`/app/data` 是容器内的数据目录。
|
||||
|
||||
### 本地运行 (适用于开发和测试)
|
||||
|
||||
如果您想在本地直接运行源代码进行开发或测试,请按照以下步骤操作:
|
||||
|
||||
1. **确保已完成准备工作**:
|
||||
* 克隆仓库到本地。
|
||||
* 安装 Python 3.9 或更高版本。
|
||||
* 在项目根目录下创建并配置好 `.env` 文件 (参考前面的"配置环境变量"部分)。
|
||||
* 安装项目依赖:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **启动应用**:
|
||||
在项目根目录下运行以下命令:
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
* `app.main:app`: 指定 FastAPI 应用实例的位置 (`app` 模块中的 `main.py` 文件里的 `app` 对象)。
|
||||
* `--host 0.0.0.0`: 使应用可以从本地网络中的任何 IP 地址访问。
|
||||
* `--port 8000`: 指定应用监听的端口号 (您可以根据需要修改)。
|
||||
* `--reload`: 启用自动重载功能。当您修改代码时,服务会自动重启,非常适合开发环境 (生产环境请移除此选项)。
|
||||
|
||||
3. **访问应用**:
|
||||
应用启动后,您可以通过浏览器或 API 工具访问 `http://localhost:8000` (或您指定的主机和端口)。
|
||||
|
||||
### 完整配置项列表
|
||||
|
||||
| 配置项 | 说明 | 默认值 |
|
||||
| :--------------------------- | :------------------------------------------------------- | :---------------------------------------------------- |
|
||||
| **数据库配置** | | |
|
||||
| `DATABASE_TYPE` | 可选,数据库类型,支持 `mysql` 或 `sqlite` | `mysql` |
|
||||
| `SQLITE_DATABASE` | 可选,当使用 `sqlite` 时必填,SQLite 数据库文件路径 | `default_db` |
|
||||
| `MYSQL_HOST` | 当使用 `mysql` 时必填,MySQL 数据库主机地址 | `localhost` |
|
||||
| `MYSQL_SOCKET` | 可选,MySQL 数据库 socket 地址 | `/var/run/mysqld/mysqld.sock` |
|
||||
| `MYSQL_PORT` | 当使用 `mysql` 时必填,MySQL 数据库端口 | `3306` |
|
||||
| `MYSQL_USER` | 当使用 `mysql` 时必填,MySQL 数据库用户名 | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | 当使用 `mysql` 时必填,MySQL 数据库密码 | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | 当使用 `mysql` 时必填,MySQL 数据库名称 | `defaultdb` |
|
||||
| **API 相关配置** | | |
|
||||
| `API_KEYS` | 必填,Gemini API 密钥列表,用于负载均衡 | `["your-gemini-api-key-1", "your-gemini-api-key-2"]` |
|
||||
| `ALLOWED_TOKENS` | 必填,允许访问的 Token 列表 | `["your-access-token-1", "your-access-token-2"]` |
|
||||
| `AUTH_TOKEN` | 可选,超级管理员token,具有所有权限,不填默认使用 ALLOWED_TOKENS 的第一个 | `sk-123456` |
|
||||
| `TEST_MODEL` | 可选,用于测试密钥是否可用的模型名 | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | 可选,支持绘图功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | 可选,支持搜索功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `FILTERED_MODELS` | 可选,被禁用的模型列表 | `["gemini-1.0-pro-vision-latest", ...]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | 可选,是否启用代码执行工具 | `false` |
|
||||
| `SHOW_SEARCH_LINK` | 可选,是否在响应中显示搜索结果链接 | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | 可选,是否显示模型思考过程 | `true` |
|
||||
| `THINKING_MODELS` | 可选,支持思考功能的模型列表 | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | 可选,思考功能预算映射 (模型名:预算值) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | 可选,是否启用智能路由映射功能 | `false` |
|
||||
| `URL_CONTEXT_ENABLED` | 可选,是否启用URL上下文理解功能 | `false` |
|
||||
| `URL_CONTEXT_MODELS` | 可选,支持URL上下文理解功能的模型列表 | `[]` |
|
||||
| `BASE_URL` | 可选,Gemini API 基础 URL,默认无需修改 | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | 可选,允许单个key失败的次数 | `3` |
|
||||
| `MAX_RETRIES` | 可选,API 请求失败时的最大重试次数 | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | 可选,检查禁用 Key 是否恢复的时间间隔 (小时) | `1` |
|
||||
| `TIMEZONE` | 可选,应用程序使用的时区 | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | 可选,请求超时时间 (秒) | `300` |
|
||||
| `PROXIES` | 可选,代理服务器列表 (例如 `http://user:pass@host:port`, `socks5://host:port`) | `[]` |
|
||||
| `LOG_LEVEL` | 可选,日志级别,例如 DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | 可选,是否开启自动删除错误日志 | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | 可选,自动删除多少天前的错误日志 (例如 1, 7, 30) | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 可选,是否开启自动删除请求日志 | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | 可选,自动删除多少天前的请求日志 (例如 1, 7, 30) | `30` |
|
||||
| `SAFETY_SETTINGS` | 可选,安全设置 (JSON 字符串格式),用于配置内容安全阈值。示例值可能需要根据实际模型支持情况调整。 | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
|
||||
| **TTS 相关** | | |
|
||||
| `TTS_MODEL` | 可选,TTS 模型名称 | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | 可选,TTS 语音名称 | `Zephyr` |
|
||||
| `TTS_SPEED` | 可选,TTS 语速 | `normal` |
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 可选,付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | 可选,图片生成模型 | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | 可选,图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `SMMS_SECRET_TOKEN` | 可选,SM.MS图床的API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | 可选,[PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | 可选,[CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| 可选,CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| 可选,CloudFlare图床的上传文件夹路径 | `""` |
|
||||
| **流式优化器相关** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | 可选,是否启用流式输出优化 | `false` |
|
||||
| `STREAM_MIN_DELAY` | 可选,流式输出最小延迟 | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | 可选,流式输出最大延迟 | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD`| 可选,短文本阈值 | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | 可选,长文本阈值 | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | 可选,流式输出块大小 | `5` |
|
||||
| **伪流式 (Fake Stream) 相关** | | |
|
||||
| `FAKE_STREAM_ENABLED` | 可选,是否启用伪流式传输,用于不支持流式的模型或场景 | `false` |
|
||||
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | 可选,伪流式传输时发送心跳空数据的间隔秒数 | `5` |
|
||||
|
||||
## ⚙️ API 端点
|
||||
|
||||
以下是服务提供的主要 API 端点:
|
||||
|
||||
### Gemini API 相关 (`(/gemini)/v1beta`)
|
||||
|
||||
* `GET /models`: 列出可用的 Gemini 模型。
|
||||
* `POST /models/{model_name}:generateContent`: 使用指定的 Gemini 模型生成内容。
|
||||
* `POST /models/{model_name}:streamGenerateContent`: 使用指定的 Gemini 模型流式生成内容。
|
||||
|
||||
### OpenAI API 相关
|
||||
|
||||
* `GET (/hf)/v1/models`: 列出可用的模型 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/chat/completions`: 进行聊天补全 (底层用的gemini格式, 支持流式传输)。
|
||||
* `POST (/hf)/v1/embeddings`: 创建文本嵌入 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/images/generations`: 生成图像 (底层用的gemini格式)。
|
||||
* `GET /openai/v1/models`: 列出可用的模型 (底层用的openai格式)。
|
||||
* `POST /openai/v1/chat/completions`: 进行聊天补全 (底层用的openai格式, 支持流式传输, 可防止截断,速度也快)。
|
||||
* `POST /openai/v1/embeddings`: 创建文本嵌入 (底层用的openai格式)。
|
||||
* `POST /openai/v1/images/generations`: 生成图像 (底层用的openai格式)。
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
欢迎提交 Pull Request 或 Issue。
|
||||
|
||||
## 🎉 特别鸣谢
|
||||
|
||||
特别鸣谢以下项目和平台为本项目提供图床服务:
|
||||
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) 开源项目
|
||||
|
||||
## 🙏 感谢贡献者
|
||||
|
||||
感谢所有为本项目做出贡献的开发者!
|
||||
|
||||
[](https://github.com/snailyp/gemini-balance/graphs/contributors)
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
[](https://star-history.com/#snailyp/gemini-balance&Date)
|
||||
|
||||
## 💖 友情项目
|
||||
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine一线:AI驱动的热点事件时间轴生成工具
|
||||
|
||||
## 🎁 项目支持
|
||||
|
||||
如果你觉得这个项目对你有帮助,可以考虑通过 [爱发电](https://afdian.com/a/snaily) 支持我。
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
@@ -1,53 +1,93 @@
|
||||
"""
|
||||
应用程序配置模块
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import List, Any, Dict, Type
|
||||
from typing import Any, Dict, List, Type, get_args, get_origin
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydantic import ValidationError, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from sqlalchemy import insert, update, select
|
||||
from sqlalchemy import insert, select, update
|
||||
|
||||
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT, MAX_RETRIES
|
||||
from app.core.constants import (
|
||||
API_VERSION,
|
||||
DEFAULT_CREATE_IMAGE_MODEL,
|
||||
DEFAULT_FILTER_MODELS,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_SAFETY_SETTINGS,
|
||||
DEFAULT_STREAM_CHUNK_SIZE,
|
||||
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||
DEFAULT_STREAM_MAX_DELAY,
|
||||
DEFAULT_STREAM_MIN_DELAY,
|
||||
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||
DEFAULT_TIMEOUT,
|
||||
MAX_RETRIES,
|
||||
)
|
||||
from app.log.logger import Logger
|
||||
# from app.log.logger import get_config_logger # 移除顶层导入
|
||||
# 延迟导入以避免循环依赖,仅在 sync_initial_settings 中使用
|
||||
# from app.database.connection import database
|
||||
# from app.database.models import Settings as SettingsModel
|
||||
# from app.database.services import get_all_settings # get_all_settings 可能不适合启动时调用,直接查询
|
||||
|
||||
# logger = get_config_logger() # 移除顶层初始化
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# 数据库配置
|
||||
MYSQL_HOST: str
|
||||
MYSQL_PORT: int
|
||||
MYSQL_USER: str
|
||||
MYSQL_PASSWORD: str
|
||||
MYSQL_DATABASE: str
|
||||
|
||||
DATABASE_TYPE: str = "mysql" # sqlite 或 mysql
|
||||
SQLITE_DATABASE: str = "default_db"
|
||||
MYSQL_HOST: str = ""
|
||||
MYSQL_PORT: int = 3306
|
||||
MYSQL_USER: str = ""
|
||||
MYSQL_PASSWORD: str = ""
|
||||
MYSQL_DATABASE: str = ""
|
||||
MYSQL_SOCKET: str = ""
|
||||
|
||||
# 验证 MySQL 配置
|
||||
@field_validator(
|
||||
"MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE"
|
||||
)
|
||||
def validate_mysql_config(cls, v: Any, info: ValidationInfo) -> Any:
|
||||
if info.data.get("DATABASE_TYPE") == "mysql":
|
||||
if v is None or v == "":
|
||||
raise ValueError(
|
||||
"MySQL configuration is required when DATABASE_TYPE is 'mysql'"
|
||||
)
|
||||
return v
|
||||
|
||||
# API相关配置
|
||||
API_KEYS: List[str]
|
||||
ALLOWED_TOKENS: List[str]
|
||||
API_KEYS: List[str]=[]
|
||||
ALLOWED_TOKENS: List[str]=[]
|
||||
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
||||
AUTH_TOKEN: str = ""
|
||||
MAX_FAILURES: int = 3
|
||||
TEST_MODEL: str = DEFAULT_MODEL
|
||||
TIME_OUT: int = DEFAULT_TIMEOUT
|
||||
MAX_RETRIES: int = MAX_RETRIES
|
||||
|
||||
PROXIES: List[str] = []
|
||||
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
|
||||
VERTEX_API_KEYS: List[str] = []
|
||||
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
||||
|
||||
# 智能路由配置
|
||||
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
|
||||
|
||||
# 自定义 Headers
|
||||
CUSTOM_HEADERS: Dict[str, str] = {}
|
||||
|
||||
# 模型相关配置
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
# 是否启用网址上下文
|
||||
URL_CONTEXT_ENABLED: bool = True
|
||||
URL_CONTEXT_MODELS: List[str] = ["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
THINKING_MODELS: List[str] = [] # 新增:用于思考过程的模型列表
|
||||
THINKING_BUDGET_MAP: Dict[str, float] = {} # 新增:模型对应的预算映射
|
||||
|
||||
THINKING_MODELS: List[str] = []
|
||||
THINKING_BUDGET_MAP: Dict[str, float] = {}
|
||||
|
||||
# TTS相关配置
|
||||
TTS_MODEL: str = "gemini-2.5-flash-preview-tts"
|
||||
TTS_VOICE_NAME: str = "Zephyr"
|
||||
TTS_SPEED: str = "normal"
|
||||
|
||||
# 图像生成相关配置
|
||||
PAID_KEY: str = ""
|
||||
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
|
||||
@@ -56,7 +96,8 @@ class Settings(BaseSettings):
|
||||
PICGO_API_KEY: str = ""
|
||||
CLOUDFLARE_IMGBED_URL: str = ""
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER: str = ""
|
||||
|
||||
# 流式输出优化器配置
|
||||
STREAM_OPTIMIZER_ENABLED: bool = False
|
||||
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
|
||||
@@ -65,16 +106,30 @@ class Settings(BaseSettings):
|
||||
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
|
||||
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
|
||||
|
||||
# 假流式配置 (Fake Streaming Configuration)
|
||||
FAKE_STREAM_ENABLED: bool = False # 是否启用假流式输出
|
||||
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: int = 5 # 假流式发送空数据的间隔时间(秒)
|
||||
|
||||
# 调度器配置
|
||||
CHECK_INTERVAL_HOURS: int = 1 # 默认检查间隔为1小时
|
||||
TIMEZONE: str = "Asia/Shanghai" # 默认时区
|
||||
|
||||
# github
|
||||
CHECK_INTERVAL_HOURS: int = 1 # 默认检查间隔为1小时
|
||||
TIMEZONE: str = "Asia/Shanghai" # 默认时区
|
||||
|
||||
# github
|
||||
GITHUB_REPO_OWNER: str = "snailyp"
|
||||
GITHUB_REPO_NAME: str = "gemini-balance"
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = "INFO" # 默认日志级别
|
||||
LOG_LEVEL: str = "INFO"
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True
|
||||
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7
|
||||
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False
|
||||
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
|
||||
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
|
||||
|
||||
# Files API
|
||||
FILES_CLEANUP_ENABLED: bool = True
|
||||
FILES_CLEANUP_INTERVAL_HOURS: int = 1
|
||||
FILES_USER_ISOLATION_ENABLED: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -82,54 +137,120 @@ class Settings(BaseSettings):
|
||||
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
|
||||
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
"""尝试将数据库字符串值解析为目标 Python 类型"""
|
||||
from app.log.logger import get_config_logger # 函数内导入
|
||||
logger = get_config_logger() # 函数内初始化
|
||||
from app.log.logger import get_config_logger
|
||||
|
||||
logger = get_config_logger()
|
||||
try:
|
||||
# 处理 List[str]
|
||||
if target_type == List[str]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
return [item.strip() for item in db_value.split(',') if item.strip()]
|
||||
logger.warning(f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list.")
|
||||
return [item.strip() for item in db_value.split(',') if item.strip()]
|
||||
# 处理 Dict[str, float]
|
||||
elif target_type == Dict[str, float]:
|
||||
parsed_dict = {}
|
||||
try:
|
||||
# First attempt: standard JSON parsing
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}")
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
||||
# Second attempt: try replacing single quotes if JSONDecodeError occurred
|
||||
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
||||
logger.warning(f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}")
|
||||
try:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
origin_type = get_origin(target_type)
|
||||
args = get_args(target_type)
|
||||
|
||||
# 处理 List 类型
|
||||
if origin_type is list:
|
||||
# 处理 List[str]
|
||||
if args and args[0] == str:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
logger.warning(
|
||||
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
||||
)
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
# 处理 List[Dict[str, str]]
|
||||
elif args and get_origin(args[0]) is dict:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
valid = all(
|
||||
isinstance(item, dict)
|
||||
and all(isinstance(k, str) for k in item.keys())
|
||||
and all(isinstance(v, str) for v in item.values())
|
||||
for item in parsed
|
||||
)
|
||||
if valid:
|
||||
return parsed
|
||||
else:
|
||||
logger.warning(f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}")
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
||||
logger.error(f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict.")
|
||||
else:
|
||||
# Log other errors (ValueError, TypeError) or JSON errors without single quotes
|
||||
logger.error(f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict.")
|
||||
return parsed_dict # Return the parsed dict or an empty one if all attempts fail
|
||||
logger.warning(
|
||||
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
||||
)
|
||||
return []
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
||||
)
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
||||
)
|
||||
return []
|
||||
# 处理 Dict 类型
|
||||
elif origin_type is dict:
|
||||
# 处理 Dict[str, str]
|
||||
if args and args == (str, str):
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): str(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict.")
|
||||
return parsed_dict
|
||||
# 处理 Dict[str, float]
|
||||
elif args and args == (str, float):
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
||||
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
||||
logger.warning(
|
||||
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
||||
)
|
||||
try:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
||||
)
|
||||
return parsed_dict
|
||||
# 处理 bool
|
||||
elif target_type == bool:
|
||||
return db_value.lower() in ('true', '1', 'yes', 'on')
|
||||
return db_value.lower() in ("true", "1", "yes", "on")
|
||||
# 处理 int
|
||||
elif target_type == int:
|
||||
return int(db_value)
|
||||
@@ -140,8 +261,11 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
else:
|
||||
return db_value
|
||||
except (ValueError, TypeError, json.JSONDecodeError) as e:
|
||||
logger.warning(f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value.")
|
||||
return db_value # 解析失败则返回原始字符串
|
||||
logger.warning(
|
||||
f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value."
|
||||
)
|
||||
return db_value # 解析失败则返回原始字符串
|
||||
|
||||
|
||||
async def sync_initial_settings():
|
||||
"""
|
||||
@@ -150,8 +274,9 @@ async def sync_initial_settings():
|
||||
2. 将数据库设置合并到内存 settings (数据库优先)。
|
||||
3. 将最终的内存 settings 同步回数据库。
|
||||
"""
|
||||
from app.log.logger import get_config_logger # 函数内导入
|
||||
logger = get_config_logger() # 函数内初始化
|
||||
from app.log.logger import get_config_logger
|
||||
|
||||
logger = get_config_logger()
|
||||
# 延迟导入以避免循环依赖和确保数据库连接已初始化
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings as SettingsModel
|
||||
@@ -164,7 +289,9 @@ async def sync_initial_settings():
|
||||
await database.connect()
|
||||
logger.info("Database connection established for initial sync.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to database for initial settings sync: {e}. Skipping sync.")
|
||||
logger.error(
|
||||
f"Failed to connect to database for initial settings sync: {e}. Skipping sync."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -173,18 +300,30 @@ async def sync_initial_settings():
|
||||
try:
|
||||
query = select(SettingsModel.key, SettingsModel.value)
|
||||
results = await database.fetch_all(query)
|
||||
db_settings_raw = [{"key": row["key"], "value": row["value"]} for row in results]
|
||||
db_settings_raw = [
|
||||
{"key": row["key"], "value": row["value"]} for row in results
|
||||
]
|
||||
logger.info(f"Fetched {len(db_settings_raw)} settings from database.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings.")
|
||||
logger.error(
|
||||
f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings."
|
||||
)
|
||||
# 即使数据库读取失败,也要继续执行,确保基于 env/dotenv 的配置能同步到数据库
|
||||
|
||||
db_settings_map: Dict[str, str] = {s['key']: s['value'] for s in db_settings_raw}
|
||||
db_settings_map: Dict[str, str] = {
|
||||
s["key"]: s["value"] for s in db_settings_raw
|
||||
}
|
||||
|
||||
# 2. 将数据库设置合并到内存 settings (数据库优先)
|
||||
updated_in_memory = False
|
||||
|
||||
for key, db_value in db_settings_map.items():
|
||||
if key == "DATABASE_TYPE":
|
||||
logger.debug(
|
||||
f"Skipping update of '{key}' in memory from database. "
|
||||
"This setting is controlled by environment/dotenv."
|
||||
)
|
||||
continue
|
||||
if hasattr(settings, key):
|
||||
target_type = Settings.__annotations__.get(key)
|
||||
if target_type:
|
||||
@@ -197,35 +336,46 @@ async def sync_initial_settings():
|
||||
if parsed_db_value != memory_value:
|
||||
# 检查类型是否匹配,以防解析函数返回了不兼容的类型
|
||||
type_match = False
|
||||
if target_type == List[str] and isinstance(parsed_db_value, list):
|
||||
type_match = True
|
||||
elif target_type == Dict[str, float] and isinstance(parsed_db_value, dict):
|
||||
type_match = True
|
||||
elif target_type not in (List[str], Dict[str, float]) and isinstance(parsed_db_value, target_type):
|
||||
origin_type = get_origin(target_type)
|
||||
if origin_type: # It's a generic type
|
||||
if isinstance(parsed_db_value, origin_type):
|
||||
type_match = True
|
||||
# It's a non-generic type, or a specific generic we want to handle
|
||||
elif isinstance(parsed_db_value, target_type):
|
||||
type_match = True
|
||||
|
||||
if type_match:
|
||||
setattr(settings, key, parsed_db_value)
|
||||
logger.info(f"Updated setting '{key}' in memory from database value ({target_type}).")
|
||||
logger.debug(
|
||||
f"Updated setting '{key}' in memory from database value ({target_type})."
|
||||
)
|
||||
updated_in_memory = True
|
||||
else:
|
||||
logger.warning(f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update.")
|
||||
logger.warning(
|
||||
f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing database setting for key '{key}': {e}")
|
||||
logger.error(
|
||||
f"Error processing database setting for key '{key}': {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Database setting '{key}' not found in Settings model definition. Ignoring.")
|
||||
|
||||
logger.warning(
|
||||
f"Database setting '{key}' not found in Settings model definition. Ignoring."
|
||||
)
|
||||
|
||||
# 如果内存中有更新,重新验证 Pydantic 模型(可选但推荐)
|
||||
if updated_in_memory:
|
||||
try:
|
||||
# 重新加载以确保类型转换和验证
|
||||
settings = Settings(**settings.model_dump())
|
||||
logger.info("Settings object re-validated after merging database values.")
|
||||
logger.info(
|
||||
"Settings object re-validated after merging database values."
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.error(f"Validation error after merging database settings: {e}. Settings might be inconsistent.")
|
||||
|
||||
logger.error(
|
||||
f"Validation error after merging database settings: {e}. Settings might be inconsistent."
|
||||
)
|
||||
|
||||
# 3. 将最终的内存 settings 同步回数据库
|
||||
final_memory_settings = settings.model_dump()
|
||||
@@ -236,21 +386,30 @@ async def sync_initial_settings():
|
||||
existing_db_keys = set(db_settings_map.keys())
|
||||
|
||||
for key, value in final_memory_settings.items():
|
||||
if key == "DATABASE_TYPE":
|
||||
logger.debug(
|
||||
f"Skipping synchronization of '{key}' to database. "
|
||||
"This setting is controlled by environment/dotenv."
|
||||
)
|
||||
continue
|
||||
|
||||
# 序列化值为字符串或 JSON 字符串
|
||||
if isinstance(value, (list, dict)): # 处理列表和字典
|
||||
db_value = json.dumps(value, ensure_ascii=False) # 使用 ensure_ascii=False 以支持非 ASCII 字符
|
||||
if isinstance(value, (list, dict)):
|
||||
db_value = json.dumps(
|
||||
value, ensure_ascii=False
|
||||
)
|
||||
elif isinstance(value, bool):
|
||||
db_value = str(value).lower()
|
||||
elif value is None: # 处理 None 值
|
||||
db_value = "" # 或者根据需要设为 NULL 或其他标记
|
||||
elif value is None:
|
||||
db_value = ""
|
||||
else:
|
||||
db_value = str(value)
|
||||
|
||||
data = {
|
||||
'key': key,
|
||||
'value': db_value,
|
||||
'description': f"{key} configuration setting", # 默认描述
|
||||
'updated_at': now
|
||||
"key": key,
|
||||
"value": db_value,
|
||||
"description": f"{key} configuration setting",
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
if key in existing_db_keys:
|
||||
@@ -259,7 +418,7 @@ async def sync_initial_settings():
|
||||
settings_to_update.append(data)
|
||||
else:
|
||||
# 如果键不在数据库中,则插入
|
||||
data['created_at'] = now
|
||||
data["created_at"] = now
|
||||
settings_to_insert.append(data)
|
||||
|
||||
# 在事务中执行批量插入和更新
|
||||
@@ -268,51 +427,78 @@ async def sync_initial_settings():
|
||||
async with database.transaction():
|
||||
if settings_to_insert:
|
||||
# 获取现有描述以避免覆盖
|
||||
query_existing = select(SettingsModel.key, SettingsModel.description).where(SettingsModel.key.in_([s['key'] for s in settings_to_insert]))
|
||||
existing_desc = {row['key']: row['description'] for row in await database.fetch_all(query_existing)}
|
||||
query_existing = select(
|
||||
SettingsModel.key, SettingsModel.description
|
||||
).where(
|
||||
SettingsModel.key.in_(
|
||||
[s["key"] for s in settings_to_insert]
|
||||
)
|
||||
)
|
||||
existing_desc = {
|
||||
row["key"]: row["description"]
|
||||
for row in await database.fetch_all(query_existing)
|
||||
}
|
||||
for item in settings_to_insert:
|
||||
item['description'] = existing_desc.get(item['key'], item['description'])
|
||||
item["description"] = existing_desc.get(
|
||||
item["key"], item["description"]
|
||||
)
|
||||
|
||||
query_insert = insert(SettingsModel).values(settings_to_insert)
|
||||
await database.execute(query=query_insert)
|
||||
logger.info(f"Synced (inserted) {len(settings_to_insert)} settings to database.")
|
||||
logger.info(
|
||||
f"Synced (inserted) {len(settings_to_insert)} settings to database."
|
||||
)
|
||||
|
||||
if settings_to_update:
|
||||
# 获取现有描述以避免覆盖
|
||||
query_existing = select(SettingsModel.key, SettingsModel.description).where(SettingsModel.key.in_([s['key'] for s in settings_to_update]))
|
||||
existing_desc = {row['key']: row['description'] for row in await database.fetch_all(query_existing)}
|
||||
query_existing = select(
|
||||
SettingsModel.key, SettingsModel.description
|
||||
).where(
|
||||
SettingsModel.key.in_(
|
||||
[s["key"] for s in settings_to_update]
|
||||
)
|
||||
)
|
||||
existing_desc = {
|
||||
row["key"]: row["description"]
|
||||
for row in await database.fetch_all(query_existing)
|
||||
}
|
||||
|
||||
for setting_data in settings_to_update:
|
||||
setting_data['description'] = existing_desc.get(setting_data['key'], setting_data['description'])
|
||||
setting_data["description"] = existing_desc.get(
|
||||
setting_data["key"], setting_data["description"]
|
||||
)
|
||||
query_update = (
|
||||
update(SettingsModel)
|
||||
.where(SettingsModel.key == setting_data['key'])
|
||||
.where(SettingsModel.key == setting_data["key"])
|
||||
.values(
|
||||
value=setting_data['value'],
|
||||
description=setting_data['description'],
|
||||
updated_at=setting_data['updated_at']
|
||||
value=setting_data["value"],
|
||||
description=setting_data["description"],
|
||||
updated_at=setting_data["updated_at"],
|
||||
)
|
||||
)
|
||||
await database.execute(query=query_update)
|
||||
logger.info(f"Synced (updated) {len(settings_to_update)} settings to database.")
|
||||
logger.info(
|
||||
f"Synced (updated) {len(settings_to_update)} settings to database."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync settings to database during startup: {str(e)}")
|
||||
logger.error(
|
||||
f"Failed to sync settings to database during startup: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logger.info("No setting changes detected between memory and database during initial sync.")
|
||||
logger.info(
|
||||
"No setting changes detected between memory and database during initial sync."
|
||||
)
|
||||
|
||||
# 刷新日志等级
|
||||
Logger.update_log_levels(final_memory_settings.get("LOG_LEVEL"))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred during initial settings sync: {e}")
|
||||
finally:
|
||||
if database.is_connected:
|
||||
try:
|
||||
# Don't disconnect if it's managed elsewhere (e.g., FastAPI lifespan)
|
||||
# await database.disconnect()
|
||||
# logger.info("Database connection closed after initial sync.")
|
||||
pass # Assume connection lifecycle is managed by the application lifespan
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting database after initial sync: {e}")
|
||||
try:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting database after initial sync: {e}")
|
||||
|
||||
logger.info("Initial settings synchronization finished.")
|
||||
|
||||
@@ -1,47 +1,32 @@
|
||||
"""
|
||||
应用程序工厂模块,负责创建和配置FastAPI应用程序实例
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.config.config import settings, sync_initial_settings
|
||||
from app.log.logger import get_application_logger
|
||||
from app.middleware.middleware import setup_middlewares
|
||||
from app.exception.exceptions import setup_exception_handlers
|
||||
from app.router.routes import setup_routers
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.core.initialization import initialize_app
|
||||
from app.database.connection import connect_to_db, disconnect_from_db
|
||||
from app.database.initialization import initialize_database
|
||||
from app.scheduler.key_checker import start_scheduler, stop_scheduler # 导入调度器函数
|
||||
from app.service.update.update_service import check_for_updates # 导入更新检查服务
|
||||
from app.exception.exceptions import setup_exception_handlers
|
||||
from app.log.logger import get_application_logger
|
||||
from app.middleware.middleware import setup_middlewares
|
||||
from app.router.routes import setup_routers
|
||||
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.update.update_service import check_for_updates
|
||||
from app.utils.helpers import get_current_version
|
||||
|
||||
logger = get_application_logger()
|
||||
|
||||
VERSION_FILE_PATH = "VERSION" # Path relative to project root
|
||||
|
||||
def _get_current_version(default_version: str = "0.0.0") -> str:
|
||||
"""Reads the current version from the VERSION file."""
|
||||
try:
|
||||
# Assuming execution from project root d:/develop/pythonProjects/gemini-balance
|
||||
with open(VERSION_FILE_PATH, 'r', encoding='utf-8') as f:
|
||||
version = f.read().strip()
|
||||
if not version:
|
||||
logger.warning(f"VERSION file ('{VERSION_FILE_PATH}') is empty. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
return version
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"VERSION file not found at '{VERSION_FILE_PATH}'. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
except IOError as e:
|
||||
logger.error(f"Error reading VERSION file ('{VERSION_FILE_PATH}'): {e}. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
STATIC_DIR = PROJECT_ROOT / "app" / "static"
|
||||
TEMPLATES_DIR = PROJECT_ROOT / "app" / "templates"
|
||||
|
||||
# 初始化模板引擎,并添加全局变量
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
# 定义一个函数来更新模板全局变量
|
||||
def update_template_globals(app: FastAPI, update_info: dict):
|
||||
# Jinja2Templates 实例没有直接更新全局变量的方法
|
||||
@@ -51,104 +36,118 @@ def update_template_globals(app: FastAPI, update_info: dict):
|
||||
logger.info(f"Update info stored in app.state: {update_info}")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
应用程序生命周期管理器
|
||||
|
||||
Args:
|
||||
app: FastAPI应用实例
|
||||
"""
|
||||
# 启动事件
|
||||
logger.info("Application starting up...")
|
||||
try:
|
||||
# 初始化数据库
|
||||
initialize_database()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
# 连接到数据库
|
||||
await connect_to_db()
|
||||
|
||||
# 同步初始配置(DB优先,然后同步回DB)
|
||||
await sync_initial_settings()
|
||||
|
||||
# 初始化KeyManager (使用可能已从DB更新的settings)
|
||||
await get_key_manager_instance(settings.API_KEYS)
|
||||
logger.info("KeyManager initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize application: {str(e)}")
|
||||
# 不重新抛出,允许应用继续运行,但记录错误
|
||||
# raise # 取消注释以在初始化失败时停止应用
|
||||
|
||||
# 检查更新 (在核心初始化之后)
|
||||
update_available, latest_version, error_message = await check_for_updates()
|
||||
update_info = {
|
||||
"update_available": update_available,
|
||||
"latest_version": latest_version,
|
||||
"error_message": error_message,
|
||||
"current_version": _get_current_version() # Read from VERSION file
|
||||
}
|
||||
# 将更新信息存储在 app.state 中
|
||||
app.state.update_info = update_info
|
||||
logger.info(f"Update check completed. Info: {update_info}")
|
||||
# --- Helper functions for lifespan ---
|
||||
async def _setup_database_and_config(app_settings):
|
||||
"""Initializes database, syncs settings, and initializes KeyManager."""
|
||||
initialize_database()
|
||||
logger.info("Database initialized successfully")
|
||||
await connect_to_db()
|
||||
await sync_initial_settings()
|
||||
await get_key_manager_instance(app_settings.API_KEYS, app_settings.VERTEX_API_KEYS)
|
||||
logger.info("Database, config sync, and KeyManager initialized successfully")
|
||||
|
||||
|
||||
# 启动调度器 (如果初始化成功)
|
||||
async def _shutdown_database():
|
||||
"""Disconnects from the database."""
|
||||
await disconnect_from_db()
|
||||
|
||||
|
||||
def _start_scheduler():
|
||||
"""Starts the background scheduler."""
|
||||
try:
|
||||
start_scheduler()
|
||||
logger.info("Scheduler started successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start scheduler: {e}")
|
||||
logger.error(f"Failed to start scheduler: {e}")
|
||||
|
||||
|
||||
yield # 应用程序运行期间
|
||||
|
||||
# 关闭事件
|
||||
logger.info("Application shutting down...")
|
||||
|
||||
# 停止调度器
|
||||
def _stop_scheduler():
|
||||
"""Stops the background scheduler."""
|
||||
stop_scheduler()
|
||||
logger.info("Scheduler stopped.")
|
||||
|
||||
# 断开数据库连接
|
||||
await disconnect_from_db()
|
||||
|
||||
async def _perform_update_check(app: FastAPI):
|
||||
"""Checks for updates and stores the info in app.state."""
|
||||
update_available, latest_version, error_message = await check_for_updates()
|
||||
current_version = get_current_version()
|
||||
update_info = {
|
||||
"update_available": update_available,
|
||||
"latest_version": latest_version,
|
||||
"error_message": error_message,
|
||||
"current_version": current_version,
|
||||
}
|
||||
if not hasattr(app, "state"):
|
||||
from starlette.datastructures import State
|
||||
|
||||
app.state = State()
|
||||
app.state.update_info = update_info
|
||||
logger.info(f"Update check completed. Info: {update_info}")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Manages the application startup and shutdown events.
|
||||
|
||||
Args:
|
||||
app: FastAPI应用实例
|
||||
"""
|
||||
logger.info("Application starting up...")
|
||||
try:
|
||||
await _setup_database_and_config(settings)
|
||||
await _perform_update_check(app)
|
||||
_start_scheduler()
|
||||
|
||||
except Exception as e:
|
||||
logger.critical(
|
||||
f"Critical error during application startup: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Application shutting down...")
|
||||
_stop_scheduler()
|
||||
await _shutdown_database()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""
|
||||
创建并配置FastAPI应用程序实例
|
||||
|
||||
|
||||
Returns:
|
||||
FastAPI: 配置好的FastAPI应用程序实例
|
||||
"""
|
||||
# 初始化应用程序
|
||||
initialize_app()
|
||||
|
||||
|
||||
# 创建FastAPI应用
|
||||
current_version = get_current_version()
|
||||
app = FastAPI(
|
||||
title="Gemini Balance API",
|
||||
description="Gemini API代理服务,支持负载均衡和密钥管理",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
version=current_version,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# 初始化 app.state (如果尚未存在)
|
||||
if not hasattr(app, "state"):
|
||||
from starlette.datastructures import State
|
||||
app.state = State()
|
||||
# 确保 update_info 即使在 lifespan 之前访问也不会出错
|
||||
app.state.update_info = {"update_available": False, "latest_version": None, "error_message": "Checking...", "current_version": _get_current_version()} # Read from VERSION file for initial state
|
||||
|
||||
app.state = State()
|
||||
app.state.update_info = {
|
||||
"update_available": False,
|
||||
"latest_version": None,
|
||||
"error_message": "Initializing...",
|
||||
"current_version": current_version,
|
||||
}
|
||||
|
||||
# 配置静态文件
|
||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
|
||||
# 配置中间件
|
||||
setup_middlewares(app)
|
||||
|
||||
|
||||
# 配置异常处理器
|
||||
setup_exception_handlers(app)
|
||||
|
||||
|
||||
# 配置路由
|
||||
setup_routers(app)
|
||||
|
||||
|
||||
return app
|
||||
|
||||
@@ -15,12 +15,12 @@ DEFAULT_MAX_TOKENS = 8192
|
||||
DEFAULT_TOP_P = 0.9
|
||||
DEFAULT_TOP_K = 40
|
||||
DEFAULT_FILTER_MODELS = [
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001"
|
||||
]
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001",
|
||||
]
|
||||
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||
|
||||
# 图像生成相关常量
|
||||
@@ -38,5 +38,75 @@ DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
||||
DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||
|
||||
# 正则表达式模式
|
||||
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||
IMAGE_URL_PATTERN = r"!\[(.*?)\]\((.*?)\)"
|
||||
DATA_URL_PATTERN = r"data:([^;]+);base64,(.+)"
|
||||
|
||||
# Audio/Video Settings
|
||||
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
||||
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
||||
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
|
||||
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
||||
AUDIO_FORMAT_TO_MIMETYPE = {
|
||||
"wav": "audio/wav",
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"ogg": "audio/ogg",
|
||||
}
|
||||
|
||||
VIDEO_FORMAT_TO_MIMETYPE = {
|
||||
"mp4": "video/mp4",
|
||||
"mov": "video/quicktime",
|
||||
"avi": "video/x-msvideo",
|
||||
"webm": "video/webm",
|
||||
}
|
||||
|
||||
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
|
||||
DEFAULT_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
TTS_VOICE_NAMES = [
|
||||
"Zephyr",
|
||||
"Puck",
|
||||
"Charon",
|
||||
"Kore",
|
||||
"Fenrir",
|
||||
"Leda",
|
||||
"Orus",
|
||||
"Aoede",
|
||||
"Callirrhoe",
|
||||
"Autonoe",
|
||||
"Enceladus",
|
||||
"Iapetus",
|
||||
"Umbriel",
|
||||
"Algieba",
|
||||
"Despina",
|
||||
"Erinome",
|
||||
"Algenib",
|
||||
"Rasalgethi",
|
||||
"Laomedeia",
|
||||
"Achernar",
|
||||
"Alnilam",
|
||||
"Schedar",
|
||||
"Gacrux",
|
||||
"Pulcherrima",
|
||||
"Achird",
|
||||
"Zubenelgenubi",
|
||||
"Vindemiatrix",
|
||||
"Sadachbia",
|
||||
"Sadaltager",
|
||||
"Sulafat",
|
||||
]
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
应用程序初始化模块
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from app.log.logger import get_initialization_logger
|
||||
|
||||
logger = get_initialization_logger()
|
||||
|
||||
|
||||
def ensure_directories_exist(directories: List[str]) -> None:
|
||||
"""
|
||||
确保指定的目录存在,如果不存在则创建
|
||||
|
||||
Args:
|
||||
directories: 要确保存在的目录列表
|
||||
"""
|
||||
for directory in directories:
|
||||
try:
|
||||
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Ensured directory exists: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create directory {directory}: {str(e)}")
|
||||
|
||||
|
||||
def initialize_app() -> None:
|
||||
"""
|
||||
初始化应用程序,确保所需的目录和文件都存在
|
||||
"""
|
||||
# 确保必要的目录存在
|
||||
required_directories = [
|
||||
"app/static/css",
|
||||
"app/static/js",
|
||||
"app/static/icons",
|
||||
"app/templates",
|
||||
]
|
||||
|
||||
ensure_directories_exist(required_directories)
|
||||
logger.info("core initialization completed")
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
数据库连接池模块
|
||||
"""
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
from databases import Database
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
@@ -11,7 +13,19 @@ from app.log.logger import get_database_logger
|
||||
logger = get_database_logger()
|
||||
|
||||
# 数据库URL
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
|
||||
if settings.DATABASE_TYPE == "sqlite":
|
||||
# 确保 data 目录存在
|
||||
data_dir = Path("data")
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
db_path = data_dir / settings.SQLITE_DATABASE
|
||||
DATABASE_URL = f"sqlite:///{db_path}"
|
||||
elif settings.DATABASE_TYPE == "mysql":
|
||||
if settings.MYSQL_SOCKET:
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
|
||||
else:
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
|
||||
else:
|
||||
raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.")
|
||||
|
||||
# 创建数据库引擎
|
||||
# pool_pre_ping=True: 在从连接池获取连接前执行简单的 "ping" 测试,确保连接有效
|
||||
@@ -23,14 +37,16 @@ metadata = MetaData()
|
||||
# 创建基类
|
||||
Base = declarative_base(metadata=metadata)
|
||||
|
||||
# 创建数据库连接池,并配置连接池参数
|
||||
# 创建数据库连接池,并配置连接池参数,在sqlite中不使用连接池
|
||||
# min_size/max_size: 连接池的最小/最大连接数
|
||||
# pool_recycle=3600: 连接在池中允许存在的最大秒数(生命周期)。
|
||||
# 设置为 3600 秒(1小时),确保在 MySQL 默认的 wait_timeout (通常8小时) 或其他网络超时之前回收连接。
|
||||
# 如果遇到连接失效问题,可以尝试调低此值,使其小于实际的 wait_timeout 或网络超时时间。
|
||||
# databases 库会自动处理连接失效后的重连尝试。
|
||||
database = Database(DATABASE_URL, min_size=5, max_size=20, pool_recycle=1800) # Reduced recycle time to 30 mins
|
||||
|
||||
if settings.DATABASE_TYPE == "sqlite":
|
||||
database = Database(DATABASE_URL)
|
||||
else:
|
||||
database = Database(DATABASE_URL, min_size=5, max_size=20, pool_recycle=1800)
|
||||
|
||||
async def connect_to_db():
|
||||
"""
|
||||
@@ -38,7 +54,7 @@ async def connect_to_db():
|
||||
"""
|
||||
try:
|
||||
await database.connect()
|
||||
logger.info("Connected to database")
|
||||
logger.info(f"Connected to {settings.DATABASE_TYPE}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to database: {str(e)}")
|
||||
raise
|
||||
@@ -50,6 +66,6 @@ async def disconnect_from_db():
|
||||
"""
|
||||
try:
|
||||
await database.disconnect()
|
||||
logger.info("Disconnected from database")
|
||||
logger.info(f"Disconnected from {settings.DATABASE_TYPE}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disconnect from database: {str(e)}")
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
数据库模型模块
|
||||
"""
|
||||
import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean # 添加 Boolean
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum
|
||||
import enum
|
||||
|
||||
from app.database.connection import Base
|
||||
|
||||
@@ -42,20 +43,87 @@ class ErrorLog(Base):
|
||||
def __repr__(self):
|
||||
return f"<ErrorLog(id='{self.id}', gemini_key='{self.gemini_key}')>"
|
||||
|
||||
# 新增 RequestLog 模型
|
||||
|
||||
class RequestLog(Base):
|
||||
"""
|
||||
API 请求日志表
|
||||
"""
|
||||
|
||||
__tablename__ = "t_request_log"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")
|
||||
model_name = Column(String(100), nullable=True, comment="模型名称")
|
||||
api_key = Column(String(100), nullable=True, comment="使用的API密钥") # 考虑安全性,后续可优化
|
||||
api_key = Column(String(100), nullable=True, comment="使用的API密钥")
|
||||
is_success = Column(Boolean, nullable=False, comment="请求是否成功")
|
||||
status_code = Column(Integer, nullable=True, comment="API响应状态码")
|
||||
latency_ms = Column(Integer, nullable=True, comment="请求耗时(毫秒)")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
|
||||
|
||||
|
||||
class FileState(enum.Enum):
|
||||
"""文件状态枚举"""
|
||||
PROCESSING = "PROCESSING"
|
||||
ACTIVE = "ACTIVE"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
class FileRecord(Base):
|
||||
"""
|
||||
文件记录表,用于存储上传到 Gemini 的文件信息
|
||||
"""
|
||||
__tablename__ = "t_file_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
# 文件基本信息
|
||||
name = Column(String(255), unique=True, nullable=False, comment="文件名称,格式: files/{file_id}")
|
||||
display_name = Column(String(255), nullable=True, comment="用户上传时的原始文件名")
|
||||
mime_type = Column(String(100), nullable=False, comment="MIME 类型")
|
||||
size_bytes = Column(BigInteger, nullable=False, comment="文件大小(字节)")
|
||||
sha256_hash = Column(String(255), nullable=True, comment="文件的 SHA256 哈希值")
|
||||
|
||||
# 状态信息
|
||||
state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="文件状态")
|
||||
|
||||
# 时间戳
|
||||
create_time = Column(DateTime, nullable=False, comment="创建时间")
|
||||
update_time = Column(DateTime, nullable=False, comment="更新时间")
|
||||
expiration_time = Column(DateTime, nullable=False, comment="过期时间")
|
||||
|
||||
# API 相关
|
||||
uri = Column(String(500), nullable=False, comment="文件访问 URI")
|
||||
api_key = Column(String(100), nullable=False, comment="上传时使用的 API Key")
|
||||
upload_url = Column(Text, nullable=True, comment="临时上传 URL(用于分块上传)")
|
||||
|
||||
# 额外信息
|
||||
user_token = Column(String(100), nullable=True, comment="上传用户的 token")
|
||||
upload_completed = Column(DateTime, nullable=True, comment="上传完成时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<FileRecord(name='{self.name}', state='{self.state.value if self.state else 'None'}', api_key='{self.api_key[:8]}...')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典格式,用于 API 响应"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"displayName": self.display_name,
|
||||
"mimeType": self.mime_type,
|
||||
"sizeBytes": str(self.size_bytes),
|
||||
"createTime": self.create_time.isoformat() + "Z",
|
||||
"updateTime": self.update_time.isoformat() + "Z",
|
||||
"expirationTime": self.expiration_time.isoformat() + "Z",
|
||||
"sha256Hash": self.sha256_hash,
|
||||
"uri": self.uri,
|
||||
"state": self.state.value if self.state else "PROCESSING"
|
||||
}
|
||||
|
||||
def is_expired(self):
|
||||
"""检查文件是否已过期"""
|
||||
# 确保比较时都是 timezone-aware
|
||||
expiration_time = self.expiration_time
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
|
||||
return datetime.datetime.now(datetime.timezone.utc) > expiration_time
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""
|
||||
数据库服务模块
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import func, desc, asc, select, insert, update, delete
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from datetime import datetime # Keep this import
|
||||
|
||||
from sqlalchemy import select, insert, update, func
|
||||
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings, ErrorLog, RequestLog # Import RequestLog
|
||||
from app.database.models import Settings, ErrorLog, RequestLog, FileRecord, FileState
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
@@ -73,7 +71,7 @@ async def update_setting(key: str, value: str, description: Optional[str] = None
|
||||
.values(
|
||||
value=value,
|
||||
description=description if description else setting["description"],
|
||||
updated_at=datetime.now() # Use datetime.now()
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
)
|
||||
await database.execute(query)
|
||||
@@ -87,8 +85,8 @@ async def update_setting(key: str, value: str, description: Optional[str] = None
|
||||
key=key,
|
||||
value=value,
|
||||
description=description,
|
||||
created_at=datetime.now(), # Use datetime.now()
|
||||
updated_at=datetime.now() # Use datetime.now()
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
)
|
||||
await database.execute(query)
|
||||
@@ -157,19 +155,25 @@ async def get_error_logs(
|
||||
offset: int = 0,
|
||||
key_search: Optional[str] = None,
|
||||
error_search: Optional[str] = None,
|
||||
error_code_search: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None,
|
||||
sort_by: str = 'id',
|
||||
sort_order: str = 'desc'
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取错误日志,支持搜索和日期过滤
|
||||
获取错误日志,支持搜索、日期过滤和排序
|
||||
|
||||
Args:
|
||||
limit (int): 限制数量
|
||||
offset (int): 偏移量
|
||||
key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配)
|
||||
error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配)
|
||||
error_code_search (Optional[str]): 错误码搜索词 (精确匹配)
|
||||
start_date (Optional[datetime]): 开始日期时间
|
||||
end_date (Optional[datetime]): 结束日期时间
|
||||
sort_by (str): 排序字段 (例如 'id', 'request_time')
|
||||
sort_order (str): 排序顺序 ('asc' or 'desc')
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 错误日志列表
|
||||
@@ -185,7 +189,6 @@ async def get_error_logs(
|
||||
ErrorLog.request_time
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if key_search:
|
||||
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
||||
if error_search:
|
||||
@@ -196,22 +199,33 @@ async def get_error_logs(
|
||||
if start_date:
|
||||
query = query.where(ErrorLog.request_time >= start_date)
|
||||
if end_date:
|
||||
# Use the datetime object directly for comparison
|
||||
query = query.where(ErrorLog.request_time < end_date)
|
||||
if error_code_search:
|
||||
try:
|
||||
error_code_int = int(error_code_search)
|
||||
query = query.where(ErrorLog.error_code == error_code_int)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter.")
|
||||
|
||||
sort_column = getattr(ErrorLog, sort_by, ErrorLog.id)
|
||||
if sort_order.lower() == 'asc':
|
||||
query = query.order_by(asc(sort_column))
|
||||
else:
|
||||
query = query.order_by(desc(sort_column))
|
||||
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
# Apply ordering, limit, and offset
|
||||
query = query.order_by(ErrorLog.id.desc()).limit(limit).offset(offset)
|
||||
|
||||
result = await database.fetch_all(query)
|
||||
return [dict(row) for row in result]
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error logs with filters: {str(e)}") # Use exception for stack trace
|
||||
logger.exception(f"Failed to get error logs with filters: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_error_logs_count(
|
||||
key_search: Optional[str] = None,
|
||||
error_search: Optional[str] = None,
|
||||
error_code_search: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> int:
|
||||
@@ -221,6 +235,7 @@ async def get_error_logs_count(
|
||||
Args:
|
||||
key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配)
|
||||
error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配)
|
||||
error_code_search (Optional[str]): 错误码搜索词 (精确匹配)
|
||||
start_date (Optional[datetime]): 开始日期时间
|
||||
end_date (Optional[datetime]): 结束日期时间
|
||||
|
||||
@@ -230,7 +245,6 @@ async def get_error_logs_count(
|
||||
try:
|
||||
query = select(func.count()).select_from(ErrorLog)
|
||||
|
||||
# Apply the same filters as get_error_logs
|
||||
if key_search:
|
||||
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
||||
if error_search:
|
||||
@@ -241,13 +255,19 @@ async def get_error_logs_count(
|
||||
if start_date:
|
||||
query = query.where(ErrorLog.request_time >= start_date)
|
||||
if end_date:
|
||||
# Use the datetime object directly for comparison
|
||||
query = query.where(ErrorLog.request_time < end_date)
|
||||
if error_code_search:
|
||||
try:
|
||||
error_code_int = int(error_code_search)
|
||||
query = query.where(ErrorLog.error_code == error_code_int)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter.")
|
||||
|
||||
|
||||
count_result = await database.fetch_one(query)
|
||||
return count_result[0] if count_result else 0
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to count error logs with filters: {str(e)}") # Use exception for stack trace
|
||||
logger.exception(f"Failed to count error logs with filters: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -273,7 +293,7 @@ async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
log_dict['request_msg'] = json.dumps(log_dict['request_msg'], ensure_ascii=False, indent=2)
|
||||
except TypeError:
|
||||
log_dict['request_msg'] = str(log_dict['request_msg']) # Fallback to string
|
||||
log_dict['request_msg'] = str(log_dict['request_msg'])
|
||||
return log_dict
|
||||
else:
|
||||
return None
|
||||
@@ -281,6 +301,93 @@ async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
||||
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
||||
"""
|
||||
根据提供的 ID 列表批量删除错误日志 (异步)。
|
||||
|
||||
Args:
|
||||
log_ids: 要删除的错误日志 ID 列表。
|
||||
|
||||
Returns:
|
||||
int: 实际删除的日志数量。
|
||||
"""
|
||||
if not log_ids:
|
||||
return 0
|
||||
try:
|
||||
# 使用 databases 执行删除
|
||||
query = delete(ErrorLog).where(ErrorLog.id.in_(log_ids))
|
||||
# execute 返回受影响的行数,但 databases 库的 execute 不直接返回 rowcount
|
||||
# 我们需要先查询是否存在,或者依赖数据库约束/触发器(如果适用)
|
||||
# 或者,我们可以执行删除并假设成功,除非抛出异常
|
||||
# 为了简单起见,我们执行删除并记录日志,不精确返回删除数量
|
||||
# 如果需要精确数量,需要先执行 SELECT COUNT(*)
|
||||
await database.execute(query)
|
||||
# 注意:databases 的 execute 不返回 rowcount,所以我们不能直接返回删除的数量
|
||||
# 返回 log_ids 的长度作为尝试删除的数量,或者返回 0/1 表示操作尝试
|
||||
logger.info(f"Attempted bulk deletion for error logs with IDs: {log_ids}")
|
||||
return len(log_ids) # 返回尝试删除的数量
|
||||
except Exception as e:
|
||||
# 数据库连接或执行错误
|
||||
logger.error(f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def delete_error_log_by_id(log_id: int) -> bool:
|
||||
"""
|
||||
根据 ID 删除单个错误日志 (异步)。
|
||||
|
||||
Args:
|
||||
log_id: 要删除的错误日志 ID。
|
||||
|
||||
Returns:
|
||||
bool: 如果成功删除返回 True,否则返回 False。
|
||||
"""
|
||||
try:
|
||||
# 先检查是否存在 (可选,但更明确)
|
||||
check_query = select(ErrorLog.id).where(ErrorLog.id == log_id)
|
||||
exists = await database.fetch_one(check_query)
|
||||
|
||||
if not exists:
|
||||
logger.warning(f"Attempted to delete non-existent error log with ID: {log_id}")
|
||||
return False
|
||||
|
||||
# 执行删除
|
||||
delete_query = delete(ErrorLog).where(ErrorLog.id == log_id)
|
||||
await database.execute(delete_query)
|
||||
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting error log with ID {log_id}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def delete_all_error_logs() -> int:
|
||||
"""
|
||||
删除所有错误日志条目。
|
||||
|
||||
Returns:
|
||||
int: 被删除的错误日志数量。
|
||||
"""
|
||||
try:
|
||||
# 1. 获取删除前的总数
|
||||
count_query = select(func.count()).select_from(ErrorLog)
|
||||
total_to_delete = await database.fetch_val(count_query)
|
||||
|
||||
if total_to_delete == 0:
|
||||
logger.info("No error logs found to delete.")
|
||||
return 0
|
||||
|
||||
# 2. 执行删除操作
|
||||
delete_query = delete(ErrorLog)
|
||||
await database.execute(delete_query)
|
||||
|
||||
logger.info(f"Successfully deleted all {total_to_delete} error logs.")
|
||||
return total_to_delete
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete all error logs: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# 新增函数:添加请求日志
|
||||
async def add_request_log(
|
||||
model_name: Optional[str],
|
||||
@@ -316,8 +423,268 @@ async def add_request_log(
|
||||
latency_ms=latency_ms
|
||||
)
|
||||
await database.execute(query)
|
||||
# logger.debug(f"Added request log: key={api_key[:4]}..., success={is_success}, model={model_name}") # Use debug level
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add request log: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 文件记录相关函数 ====================
|
||||
|
||||
async def create_file_record(
|
||||
name: str,
|
||||
mime_type: str,
|
||||
size_bytes: int,
|
||||
api_key: str,
|
||||
uri: str,
|
||||
create_time: datetime,
|
||||
update_time: datetime,
|
||||
expiration_time: datetime,
|
||||
state: FileState = FileState.PROCESSING,
|
||||
display_name: Optional[str] = None,
|
||||
sha256_hash: Optional[str] = None,
|
||||
upload_url: Optional[str] = None,
|
||||
user_token: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
mime_type: MIME 类型
|
||||
size_bytes: 文件大小(字节)
|
||||
api_key: 上传时使用的 API Key
|
||||
uri: 文件访问 URI
|
||||
create_time: 创建时间
|
||||
update_time: 更新时间
|
||||
expiration_time: 过期时间
|
||||
display_name: 显示名称
|
||||
sha256_hash: SHA256 哈希值
|
||||
upload_url: 临时上传 URL
|
||||
user_token: 上传用户的 token
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 创建的文件记录
|
||||
"""
|
||||
try:
|
||||
query = insert(FileRecord).values(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
mime_type=mime_type,
|
||||
size_bytes=size_bytes,
|
||||
sha256_hash=sha256_hash,
|
||||
state=state,
|
||||
create_time=create_time,
|
||||
update_time=update_time,
|
||||
expiration_time=expiration_time,
|
||||
uri=uri,
|
||||
api_key=api_key,
|
||||
upload_url=upload_url,
|
||||
user_token=user_token
|
||||
)
|
||||
await database.execute(query)
|
||||
|
||||
# 返回创建的记录
|
||||
return await get_file_record_by_name(name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file record: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据文件名获取文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 文件记录,如果不存在则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord).where(FileRecord.name == name)
|
||||
result = await database.fetch_one(query)
|
||||
return dict(result) if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file record by name {name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def update_file_record_state(
|
||||
file_name: str,
|
||||
state: FileState,
|
||||
update_time: Optional[datetime] = None,
|
||||
upload_completed: Optional[datetime] = None,
|
||||
sha256_hash: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新文件记录状态
|
||||
|
||||
Args:
|
||||
file_name: 文件名
|
||||
state: 新状态
|
||||
update_time: 更新时间
|
||||
upload_completed: 上传完成时间
|
||||
sha256_hash: SHA256 哈希值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
try:
|
||||
values = {"state": state}
|
||||
if update_time:
|
||||
values["update_time"] = update_time
|
||||
if upload_completed:
|
||||
values["upload_completed"] = upload_completed
|
||||
if sha256_hash:
|
||||
values["sha256_hash"] = sha256_hash
|
||||
|
||||
query = update(FileRecord).where(FileRecord.name == file_name).values(**values)
|
||||
result = await database.execute(query)
|
||||
|
||||
if result:
|
||||
logger.info(f"Updated file record state for {file_name} to {state}")
|
||||
return True
|
||||
|
||||
logger.warning(f"File record not found for update: {file_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file record state: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def list_file_records(
|
||||
user_token: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
page_size: int = 10,
|
||||
page_token: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
列出文件记录
|
||||
|
||||
Args:
|
||||
user_token: 用户 token(如果提供,只返回该用户的文件)
|
||||
api_key: API Key(如果提供,只返回使用该 key 的文件)
|
||||
page_size: 每页大小
|
||||
page_token: 分页标记(偏移量)
|
||||
|
||||
Returns:
|
||||
tuple[List[Dict[str, Any]], Optional[str]]: (文件列表, 下一页标记)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"list_file_records called with page_size={page_size}, page_token={page_token}")
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time > datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
if user_token:
|
||||
query = query.where(FileRecord.user_token == user_token)
|
||||
if api_key:
|
||||
query = query.where(FileRecord.api_key == api_key)
|
||||
|
||||
# 使用偏移量进行分页
|
||||
offset = 0
|
||||
if page_token:
|
||||
try:
|
||||
offset = int(page_token)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid page token: {page_token}")
|
||||
offset = 0
|
||||
|
||||
# 按ID升序排列,使用 OFFSET 和 LIMIT
|
||||
query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
logger.debug(f"Query returned {len(results)} records")
|
||||
if results:
|
||||
logger.debug(f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}")
|
||||
|
||||
# 处理分页
|
||||
has_next = len(results) > page_size
|
||||
if has_next:
|
||||
results = results[:page_size]
|
||||
# 下一页的偏移量是当前偏移量加上本页返回的记录数
|
||||
next_offset = offset + page_size
|
||||
next_page_token = str(next_offset)
|
||||
logger.debug(f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}")
|
||||
else:
|
||||
next_page_token = None
|
||||
logger.debug(f"No next page, returning {len(results)} results")
|
||||
|
||||
return [dict(row) for row in results], next_page_token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def delete_file_record(name: str) -> bool:
|
||||
"""
|
||||
删除文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
query = delete(FileRecord).where(FileRecord.name == name)
|
||||
await database.execute(query)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file record: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_expired_file_records() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
删除已过期的文件记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 删除的记录列表
|
||||
"""
|
||||
try:
|
||||
# 先获取要删除的记录
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
||||
)
|
||||
expired_records = await database.fetch_all(query)
|
||||
|
||||
if not expired_records:
|
||||
return []
|
||||
|
||||
# 执行删除
|
||||
delete_query = delete(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
||||
)
|
||||
await database.execute(delete_query)
|
||||
|
||||
logger.info(f"Deleted {len(expired_records)} expired file records")
|
||||
return [dict(record) for record in expired_records]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete expired file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_api_key(name: str) -> Optional[str]:
|
||||
"""
|
||||
获取文件对应的 API Key
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
Optional[str]: API Key,如果文件不存在或已过期则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord.api_key).where(
|
||||
(FileRecord.name == name) &
|
||||
(FileRecord.expiration_time > datetime.now(timezone.utc))
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
return result["api_key"] if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file API key: {str(e)}")
|
||||
raise
|
||||
|
||||
69
app/domain/file_models.py
Normal file
69
app/domain/file_models.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Files API 相关的领域模型
|
||||
"""
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
|
||||
"""文件上传配置"""
|
||||
mime_type: Optional[str] = Field(None, description="MIME 类型")
|
||||
display_name: Optional[str] = Field(None, description="显示名称,最多40个字符")
|
||||
|
||||
|
||||
class CreateFileRequest(BaseModel):
|
||||
"""创建文件请求(用于初始化上传)"""
|
||||
file: Optional[Dict[str, Any]] = Field(None, description="文件元数据")
|
||||
|
||||
|
||||
class FileMetadata(BaseModel):
|
||||
"""文件元数据响应"""
|
||||
name: str = Field(..., description="文件名称,格式: files/{file_id}")
|
||||
displayName: Optional[str] = Field(None, description="显示名称")
|
||||
mimeType: str = Field(..., description="MIME 类型")
|
||||
sizeBytes: str = Field(..., description="文件大小(字节)")
|
||||
createTime: str = Field(..., description="创建时间 (RFC3339)")
|
||||
updateTime: str = Field(..., description="更新时间 (RFC3339)")
|
||||
expirationTime: str = Field(..., description="过期时间 (RFC3339)")
|
||||
sha256Hash: Optional[str] = Field(None, description="SHA256 哈希值")
|
||||
uri: str = Field(..., description="文件访问 URI")
|
||||
state: str = Field(..., description="文件状态")
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() + "Z"
|
||||
}
|
||||
|
||||
|
||||
class ListFilesRequest(BaseModel):
|
||||
"""列出文件请求参数"""
|
||||
pageSize: Optional[int] = Field(10, ge=1, le=100, description="每页大小")
|
||||
pageToken: Optional[str] = Field(None, description="分页标记")
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
"""列出文件响应"""
|
||||
files: List[FileMetadata] = Field(default_factory=list, description="文件列表")
|
||||
nextPageToken: Optional[str] = Field(None, description="下一页标记")
|
||||
|
||||
|
||||
class UploadInitResponse(BaseModel):
|
||||
"""上传初始化响应(内部使用)"""
|
||||
file_metadata: FileMetadata
|
||||
upload_url: str
|
||||
|
||||
|
||||
class FileKeyMapping(BaseModel):
|
||||
"""文件与 API Key 的映射关系(内部使用)"""
|
||||
file_name: str
|
||||
api_key: str
|
||||
user_token: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
"""删除文件响应"""
|
||||
success: bool = Field(..., description="是否删除成功")
|
||||
message: Optional[str] = Field(None, description="消息")
|
||||
@@ -1,12 +1,30 @@
|
||||
from typing import List, Optional, Dict, Any, Literal, Union
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
||||
|
||||
|
||||
class SafetySetting(BaseModel):
|
||||
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
|
||||
threshold: Optional[Literal["HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None
|
||||
category: Optional[
|
||||
Literal[
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_HARASSMENT",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
]
|
||||
] = None
|
||||
threshold: Optional[
|
||||
Literal[
|
||||
"HARM_BLOCK_THRESHOLD_UNSPECIFIED",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_NONE",
|
||||
"OFF",
|
||||
]
|
||||
] = None
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
@@ -22,24 +40,37 @@ class GenerationConfig(BaseModel):
|
||||
frequencyPenalty: Optional[float] = None
|
||||
responseLogprobs: Optional[bool] = None
|
||||
logprobs: Optional[int] = None
|
||||
thinkingConfig: Optional[Dict[str, Any]] = None
|
||||
# TTS相关字段
|
||||
responseModalities: Optional[List[str]] = None
|
||||
speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
|
||||
role: str = "system"
|
||||
parts: List[Dict[str, Any]]
|
||||
role: Optional[str] = "system"
|
||||
parts: Union[List[Dict[str, Any]], Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
role: str
|
||||
role: Optional[str] = None
|
||||
parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiRequest(BaseModel):
|
||||
contents: List[GeminiContent] = []
|
||||
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
||||
safetySettings: Optional[List[SafetySetting]] = None
|
||||
generationConfig: Optional[GenerationConfig] = None
|
||||
systemInstruction: Optional[SystemInstruction] = None
|
||||
safetySettings: Optional[List[SafetySetting]] = Field(
|
||||
default=None, alias="safety_settings"
|
||||
)
|
||||
generationConfig: Optional[GenerationConfig] = Field(
|
||||
default=None, alias="generation_config"
|
||||
)
|
||||
systemInstruction: Optional[SystemInstruction] = Field(
|
||||
default=None, alias="system_instruction"
|
||||
)
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class ResetSelectedKeysRequest(BaseModel):
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ImageMetadata:
|
||||
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: str | None = None):
|
||||
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: Union[str, None] = None):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
self.url = url
|
||||
self.delete_url = delete_url
|
||||
|
||||
|
||||
class UploadResponse:
|
||||
def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
|
||||
self.success = success
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.data = data
|
||||
|
||||
|
||||
class ImageUploader:
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
||||
|
||||
@@ -9,11 +9,14 @@ class ChatRequest(BaseModel):
|
||||
model: str = DEFAULT_MODEL
|
||||
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[List[dict]] = []
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = DEFAULT_TOP_P
|
||||
top_k: Optional[int] = DEFAULT_TOP_K
|
||||
stop: Optional[List[str]] = []
|
||||
stop: Optional[Union[List[str],str]] = None
|
||||
reasoning_effort: Optional[str] = None
|
||||
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
||||
tool_choice: Optional[str] = None
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
@@ -23,10 +26,17 @@ class EmbeddingRequest(BaseModel):
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
model: str = "DALL-E-3"
|
||||
model: str = "imagen-3.0-generate-002"
|
||||
prompt: str = ""
|
||||
n: int = 1
|
||||
size: Optional[str] = "1024x1024"
|
||||
quality: Optional[str] = ""
|
||||
style: Optional[str] = ""
|
||||
quality: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
response_format: Optional[str] = "url"
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
model: str = "gemini-2.5-flash-preview-tts"
|
||||
input: str
|
||||
voice: str = "Kore"
|
||||
response_format: Optional[str] = "wav"
|
||||
|
||||
32
app/handler/error_handler.py
Normal file
32
app/handler/error_handler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import HTTPException
|
||||
import logging
|
||||
|
||||
@asynccontextmanager
|
||||
async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None):
|
||||
"""
|
||||
一个异步上下文管理器,用于统一处理 FastAPI 路由中的常见错误和日志记录。
|
||||
|
||||
Args:
|
||||
logger: 用于记录日志的 Logger 实例。
|
||||
operation_name: 操作的名称,用于日志记录和错误详情。
|
||||
success_message: 操作成功时记录的自定义消息 (可选)。
|
||||
failure_message: 操作失败时记录的自定义消息 (可选)。
|
||||
"""
|
||||
default_success_msg = f"{operation_name} request successful"
|
||||
default_failure_msg = f"{operation_name} request failed"
|
||||
|
||||
logger.info("-" * 50 + operation_name + "-" * 50)
|
||||
try:
|
||||
yield
|
||||
logger.info(success_message or default_success_msg)
|
||||
except HTTPException as http_exc:
|
||||
# 如果已经是 HTTPException,直接重新抛出,保留原始状态码和详情
|
||||
logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})")
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
# 对于其他所有异常,记录错误并抛出标准的 500 错误
|
||||
logger.error(f"{failure_message or default_failure_msg}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Internal server error during {operation_name}"
|
||||
) from e
|
||||
@@ -1,62 +1,70 @@
|
||||
# app/services/chat/message_converter.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
import requests
|
||||
import base64
|
||||
|
||||
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
|
||||
import requests
|
||||
|
||||
from app.core.constants import (
|
||||
AUDIO_FORMAT_TO_MIMETYPE,
|
||||
DATA_URL_PATTERN,
|
||||
IMAGE_URL_PATTERN,
|
||||
MAX_AUDIO_SIZE_BYTES,
|
||||
MAX_VIDEO_SIZE_BYTES,
|
||||
SUPPORTED_AUDIO_FORMATS,
|
||||
SUPPORTED_ROLES,
|
||||
SUPPORTED_VIDEO_FORMATS,
|
||||
VIDEO_FORMAT_TO_MIMETYPE,
|
||||
)
|
||||
from app.log.logger import get_message_converter_logger
|
||||
|
||||
logger = get_message_converter_logger()
|
||||
|
||||
|
||||
class MessageConverter(ABC):
|
||||
"""消息转换器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
|
||||
def _get_mime_type_and_data(base64_string):
|
||||
"""
|
||||
从 base64 字符串中提取 MIME 类型和数据。
|
||||
|
||||
|
||||
参数:
|
||||
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
|
||||
|
||||
|
||||
返回:
|
||||
tuple: (mime_type, encoded_data)
|
||||
"""
|
||||
# 检查字符串是否以 "data:" 格式开始
|
||||
if base64_string.startswith('data:'):
|
||||
if base64_string.startswith("data:"):
|
||||
# 提取 MIME 类型和数据
|
||||
pattern = DATA_URL_PATTERN
|
||||
match = re.match(pattern, base64_string)
|
||||
if match:
|
||||
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
mime_type = (
|
||||
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
)
|
||||
encoded_data = match.group(2)
|
||||
return mime_type, encoded_data
|
||||
|
||||
|
||||
# 如果不是预期格式,假定它只是数据部分
|
||||
return None, base64_string
|
||||
|
||||
|
||||
def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||
if image_url.startswith("data:image"):
|
||||
mime_type, encoded_data = _get_mime_type_and_data(image_url)
|
||||
return {
|
||||
"inline_data": {
|
||||
"mime_type": mime_type,
|
||||
"data": encoded_data
|
||||
}
|
||||
}
|
||||
return {"inline_data": {"mime_type": mime_type, "data": encoded_data}}
|
||||
else:
|
||||
encoded_data = _convert_image_to_base64(image_url)
|
||||
return {
|
||||
"inline_data": {
|
||||
"mime_type": "image/png",
|
||||
"data": encoded_data
|
||||
}
|
||||
}
|
||||
return {"inline_data": {"mime_type": "image/png", "data": encoded_data}}
|
||||
|
||||
|
||||
def _convert_image_to_base64(url: str) -> str:
|
||||
@@ -70,7 +78,7 @@ def _convert_image_to_base64(url: str) -> str:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# 将图片内容转换为base64
|
||||
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||
img_data = base64.b64encode(response.content).decode("utf-8")
|
||||
return img_data
|
||||
else:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
@@ -94,12 +102,9 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
# 将URL对应的图片转换为base64
|
||||
try:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append({
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": base64_data
|
||||
}
|
||||
})
|
||||
parts.append(
|
||||
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
||||
)
|
||||
except Exception:
|
||||
# 如果转换失败,回退到文本模式
|
||||
parts.append({"text": text})
|
||||
@@ -112,42 +117,205 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
class OpenAIMessageConverter(MessageConverter):
|
||||
"""OpenAI消息格式转换器"""
|
||||
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
def _validate_media_data(
|
||||
self, format: str, data: str, supported_formats: List[str], max_size: int
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Validates format and size of Base64 media data."""
|
||||
if format.lower() not in supported_formats:
|
||||
logger.error(
|
||||
f"Unsupported media format: {format}. Supported: {supported_formats}"
|
||||
)
|
||||
raise ValueError(f"Unsupported media format: {format}")
|
||||
|
||||
try:
|
||||
decoded_data = base64.b64decode(data, validate=True)
|
||||
if len(decoded_data) > max_size:
|
||||
logger.error(
|
||||
f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
|
||||
)
|
||||
raise ValueError(
|
||||
f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
|
||||
)
|
||||
return data
|
||||
except base64.binascii.Error as e:
|
||||
logger.error(f"Invalid Base64 data provided: {e}")
|
||||
raise ValueError("Invalid Base64 data")
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating media data: {e}")
|
||||
raise
|
||||
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
system_instruction_parts = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
|
||||
parts = []
|
||||
# 特别处理最后一个assistant的消息,按\n\n分割
|
||||
if "content" in msg and isinstance(msg["content"], str) and msg["content"] and role == "assistant" and idx == len(messages) - 2:
|
||||
# 按\n\n分割消息
|
||||
content_parts = msg["content"].split("\n\n")
|
||||
for part in content_parts:
|
||||
if not part.strip(): # 跳过空内容
|
||||
|
||||
if "content" in msg and isinstance(msg["content"], list):
|
||||
for content_item in msg["content"]:
|
||||
if not isinstance(content_item, dict):
|
||||
logger.warning(
|
||||
f"Skipping unexpected content item format: {type(content_item)}"
|
||||
)
|
||||
continue
|
||||
# 处理可能包含图片的文本
|
||||
parts.extend(_process_text_with_image(part))
|
||||
elif "content" in msg and isinstance(msg["content"], str) and msg["content"]:
|
||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
||||
|
||||
content_type = content_item.get("type")
|
||||
|
||||
if content_type == "text" and content_item.get("text"):
|
||||
parts.append({"text": content_item["text"]})
|
||||
elif content_type == "image_url" and content_item.get(
|
||||
"image_url", {}
|
||||
).get("url"):
|
||||
try:
|
||||
parts.append(
|
||||
_convert_image(content_item["image_url"]["url"])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
|
||||
)
|
||||
parts.append(
|
||||
{
|
||||
"text": f"[Error processing image: {content_item['image_url']['url']}]"
|
||||
}
|
||||
)
|
||||
elif content_type == "input_audio" and content_item.get(
|
||||
"input_audio"
|
||||
):
|
||||
audio_info = content_item["input_audio"]
|
||||
audio_data = audio_info.get("data")
|
||||
audio_format = audio_info.get("format", "").lower()
|
||||
|
||||
if not audio_data or not audio_format:
|
||||
logger.warning(
|
||||
"Skipping audio part due to missing data or format."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
validated_data = self._validate_media_data(
|
||||
audio_format,
|
||||
audio_data,
|
||||
SUPPORTED_AUDIO_FORMATS,
|
||||
MAX_AUDIO_SIZE_BYTES,
|
||||
)
|
||||
|
||||
# Get MIME type
|
||||
mime_type = AUDIO_FORMAT_TO_MIMETYPE.get(audio_format)
|
||||
if not mime_type:
|
||||
# Should not happen if format validation passed, but double-check
|
||||
logger.error(
|
||||
f"Could not find MIME type for supported format: {audio_format}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Internal error: MIME type mapping missing for {audio_format}"
|
||||
)
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": mime_type,
|
||||
"data": validated_data, # Use the validated Base64 data
|
||||
}
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Successfully added audio part (format: {audio_format})"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Skipping audio part due to validation error: {e}"
|
||||
)
|
||||
parts.append({"text": f"[Error processing audio: {e}]"})
|
||||
except Exception:
|
||||
logger.exception("Unexpected error processing audio part.")
|
||||
parts.append(
|
||||
{"text": "[Unexpected error processing audio]"}
|
||||
)
|
||||
|
||||
elif content_type == "input_video" and content_item.get(
|
||||
"input_video"
|
||||
):
|
||||
video_info = content_item["input_video"]
|
||||
video_data = video_info.get("data")
|
||||
video_format = video_info.get("format", "").lower()
|
||||
|
||||
if not video_data or not video_format:
|
||||
logger.warning(
|
||||
"Skipping video part due to missing data or format."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
validated_data = self._validate_media_data(
|
||||
video_format,
|
||||
video_data,
|
||||
SUPPORTED_VIDEO_FORMATS,
|
||||
MAX_VIDEO_SIZE_BYTES,
|
||||
)
|
||||
mime_type = VIDEO_FORMAT_TO_MIMETYPE.get(video_format)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Internal error: MIME type mapping missing for {video_format}"
|
||||
)
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": mime_type,
|
||||
"data": validated_data,
|
||||
}
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Successfully added video part (format: {video_format})"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Skipping video part due to validation error: {e}"
|
||||
)
|
||||
parts.append({"text": f"[Error processing video: {e}]"})
|
||||
except Exception:
|
||||
logger.exception("Unexpected error processing video part.")
|
||||
parts.append(
|
||||
{"text": "[Unexpected error processing video]"}
|
||||
)
|
||||
|
||||
else:
|
||||
# Log unrecognized but present types
|
||||
if content_type:
|
||||
logger.warning(
|
||||
f"Unsupported content type or missing data in structured content: {content_type}"
|
||||
)
|
||||
|
||||
elif (
|
||||
"content" in msg and isinstance(msg["content"], str) and msg["content"]
|
||||
):
|
||||
parts.extend(_process_text_with_image(msg["content"]))
|
||||
elif "content" in msg and isinstance(msg["content"], list):
|
||||
for content in msg["content"]:
|
||||
if isinstance(content, str) and content:
|
||||
parts.append({"text": content})
|
||||
elif isinstance(content, dict):
|
||||
if content["type"] == "text" and content["text"]:
|
||||
parts.append({"text": content["text"]})
|
||||
elif content["type"] == "image_url":
|
||||
parts.append(_convert_image(content["image_url"]["url"]))
|
||||
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||
# Keep existing tool call processing
|
||||
for tool_call in msg["tool_calls"]:
|
||||
function_call = tool_call.get("function",{})
|
||||
function_call["args"] = json.loads(function_call.get("arguments","{}"))
|
||||
del function_call["arguments"]
|
||||
function_call = tool_call.get("function", {})
|
||||
# Sanitize arguments loading
|
||||
arguments_str = function_call.get("arguments", "{}")
|
||||
try:
|
||||
function_call["args"] = json.loads(arguments_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to decode tool call arguments: {arguments_str}"
|
||||
)
|
||||
function_call["args"] = {}
|
||||
if "arguments" in function_call:
|
||||
if "arguments" in function_call:
|
||||
del function_call["arguments"]
|
||||
|
||||
parts.append({"functionCall": function_call})
|
||||
|
||||
|
||||
if role not in SUPPORTED_ROLES:
|
||||
if role == "tool":
|
||||
role = "user"
|
||||
@@ -159,7 +327,14 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
role = "model"
|
||||
if parts:
|
||||
if role == "system":
|
||||
system_instruction_parts.extend(parts)
|
||||
text_only_parts = [p for p in parts if "text" in p]
|
||||
if len(text_only_parts) != len(parts):
|
||||
logger.warning(
|
||||
"Non-text parts found in system message; discarding them."
|
||||
)
|
||||
if text_only_parts:
|
||||
system_instruction_parts.extend(text_only_parts)
|
||||
|
||||
else:
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
@@ -171,4 +346,4 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
"parts": system_instruction_parts,
|
||||
}
|
||||
)
|
||||
return converted_messages, system_instruction
|
||||
return converted_messages, system_instruction
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
# app/services/chat/response_handler.py
|
||||
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
from app.log.logger import get_openai_logger
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
class ResponseHandler(ABC):
|
||||
"""响应处理器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -27,32 +31,44 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
self.thinking_first = True
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
if stream:
|
||||
return _handle_gemini_stream_response(response, model, stream)
|
||||
return _handle_gemini_normal_response(response, model, stream)
|
||||
|
||||
|
||||
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
|
||||
if not text and not tool_calls:
|
||||
def _handle_openai_stream_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
text, reasoning_content, tool_calls, _ = _extract_result(
|
||||
response, model, stream=True, gemini_format=False
|
||||
)
|
||||
if not text and not tool_calls and not reasoning_content:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {"content": text, "role": "assistant"}
|
||||
delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
|
||||
return {
|
||||
template_chunk = {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
|
||||
}
|
||||
if usage_metadata:
|
||||
template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}
|
||||
return template_chunk
|
||||
|
||||
|
||||
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
|
||||
def _handle_openai_normal_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
text, reasoning_content, tool_calls, _ = _extract_result(
|
||||
response, model, stream=False, gemini_format=False
|
||||
)
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
@@ -61,11 +77,16 @@ def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
"usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)},
|
||||
}
|
||||
|
||||
|
||||
@@ -78,81 +99,94 @@ class OpenAIResponseHandler(ResponseHandler):
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
usage_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if stream:
|
||||
return _handle_openai_stream_response(response, model, finish_reason)
|
||||
return _handle_openai_normal_response(response, model, finish_reason)
|
||||
|
||||
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
|
||||
return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
|
||||
return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)
|
||||
|
||||
def handle_image_chat_response(
|
||||
self, image_str: str, model: str, stream=False, finish_reason="stop"
|
||||
):
|
||||
if stream:
|
||||
return _handle_openai_stream_image_response(image_str,model,finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str,model,finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
return _handle_openai_stream_image_response(image_str, model, finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str, model, finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(
|
||||
image_str: str, model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason
|
||||
}]
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
def _handle_openai_normal_image_response(
|
||||
image_str: str, model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": image_str
|
||||
},
|
||||
"finish_reason": finish_reason
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": image_str},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
|
||||
text, tool_calls = "", []
|
||||
def _extract_result(
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
gemini_format: bool = False,
|
||||
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
|
||||
text, reasoning_content, tool_calls, thought = "", "", [], None
|
||||
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
if not parts:
|
||||
return "", []
|
||||
logger.warning("No parts found in stream response")
|
||||
return "", None, [], None
|
||||
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
if "thought" in parts[0]:
|
||||
if not gemini_format and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content = text
|
||||
text = ""
|
||||
thought = parts[0].get("thought")
|
||||
elif "executableCode" in parts[0]:
|
||||
text = _format_code_block(parts[0]["executableCode"])
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = _format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
text = _format_execution_result(parts[0]["executableCodeResult"])
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
text = _format_execution_result(parts[0]["codeExecutionResult"])
|
||||
elif "inlineData" in parts[0]:
|
||||
text = _extract_image_data(parts[0])
|
||||
else:
|
||||
@@ -162,66 +196,82 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
||||
else:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
if "thinking" in model:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = (
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
)
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = candidate["content"]["parts"][1]["text"]
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
text = ""
|
||||
if "parts" in candidate["content"]:
|
||||
for part in candidate["content"]["parts"]:
|
||||
text, reasoning_content = "", ""
|
||||
|
||||
# 使用安全的访问方式
|
||||
content = candidate.get("content", {})
|
||||
|
||||
if content and isinstance(content, dict):
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts:
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
if "thought" in part and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content += part["text"]
|
||||
else:
|
||||
text += part["text"]
|
||||
if "thought" in part and thought is None:
|
||||
thought = part.get("thought")
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
|
||||
else:
|
||||
logger.warning(f"No parts found in content for model: {model}")
|
||||
else:
|
||||
logger.error(f"Invalid content structure for model: {model}")
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
|
||||
|
||||
# 安全地获取 parts 用于工具调用提取
|
||||
parts = candidate.get("content", {}).get("parts", [])
|
||||
tool_calls = _extract_tool_calls(parts, gemini_format)
|
||||
else:
|
||||
logger.warning(f"No candidates found in response for model: {model}")
|
||||
text = "暂无返回"
|
||||
return text, tool_calls
|
||||
|
||||
return text, reasoning_content, tool_calls, thought
|
||||
|
||||
|
||||
def _extract_image_data(part: dict) -> str:
|
||||
image_uploader = None
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
base64_data = part["inlineData"]["data"]
|
||||
#将base64_data转成bytes数组
|
||||
# 将base64_data转成bytes数组
|
||||
bytes_data = base64.b64decode(base64_data)
|
||||
upload_response = image_uploader.upload(bytes_data,filename)
|
||||
upload_response = image_uploader.upload(bytes_data, filename)
|
||||
if upload_response.success:
|
||||
text = f"\n\n\n\n"
|
||||
else:
|
||||
text = ""
|
||||
return text
|
||||
|
||||
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
def _extract_tool_calls(
|
||||
parts: List[Dict[str, Any]], gemini_format: bool
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""提取工具调用信息"""
|
||||
if not parts or not isinstance(parts, list):
|
||||
return []
|
||||
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
|
||||
tool_calls = list()
|
||||
|
||||
for i in range(len(parts)):
|
||||
part = parts[i]
|
||||
if not part or not isinstance(part, dict):
|
||||
@@ -230,7 +280,7 @@ def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> Lis
|
||||
item = part.get("functionCall", {})
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
|
||||
if gemini_format:
|
||||
tool_calls.append(part)
|
||||
else:
|
||||
@@ -250,22 +300,38 @@ def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> Lis
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
def _handle_gemini_stream_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
part = {"text": text}
|
||||
if thought is not None:
|
||||
part["thought"] = thought
|
||||
content = {"parts": [part], "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
|
||||
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
def _handle_gemini_normal_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
parts = []
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
parts = tool_calls
|
||||
else:
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
if thought is not None:
|
||||
parts.append({"text": reasoning_content,"thought": thought})
|
||||
part = {"text": text}
|
||||
parts.append(part)
|
||||
content = {"parts": parts, "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
@@ -279,10 +345,10 @@ def _format_code_block(code_data: dict) -> str:
|
||||
|
||||
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# app/services/chat/retry_handler.py
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from app.core.constants import MAX_RETRIES
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_retry_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -13,8 +12,7 @@ logger = get_retry_logger()
|
||||
class RetryHandler:
|
||||
"""重试处理装饰器"""
|
||||
|
||||
def __init__(self, max_retries: int = MAX_RETRIES, key_arg: str = "api_key"):
|
||||
self.max_retries = max_retries
|
||||
def __init__(self, key_arg: str = "api_key"):
|
||||
self.key_arg = key_arg
|
||||
|
||||
def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
|
||||
@@ -22,14 +20,14 @@ class RetryHandler:
|
||||
async def wrapper(*args, **kwargs) -> T:
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
for attempt in range(settings.MAX_RETRIES):
|
||||
retries = attempt + 1
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(
|
||||
f"API call failed with error: {str(e)}. Attempt {retries} of {self.max_retries}"
|
||||
f"API call failed with error: {str(e)}. Attempt {retries} of {settings.MAX_RETRIES}"
|
||||
)
|
||||
|
||||
# 从函数参数中获取 key_manager
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# app/services/chat/stream_optimizer.py
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
@@ -107,15 +106,11 @@ class StreamOptimizer:
|
||||
|
||||
# 计算智能延迟时间
|
||||
delay = self.calculate_delay(len(text))
|
||||
# if self.logger:
|
||||
# self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
|
||||
|
||||
# 根据文本长度决定输出方式
|
||||
if len(text) >= self.long_text_threshold:
|
||||
# 长文本:分块输出
|
||||
chunks = self.split_text_into_chunks(text)
|
||||
# if self.logger:
|
||||
# self.logger.info(f"Long text: splitting into {len(chunks)} chunks")
|
||||
for chunk_text in chunks:
|
||||
chunk_response = create_response_chunk(chunk_text)
|
||||
yield format_chunk(chunk_response)
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import logging
|
||||
import platform
|
||||
import sys
|
||||
from typing import Dict, Optional
|
||||
import platform
|
||||
|
||||
# ANSI转义序列颜色代码
|
||||
COLORS = {
|
||||
'DEBUG': '\033[34m', # 蓝色
|
||||
'INFO': '\033[32m', # 绿色
|
||||
'WARNING': '\033[33m', # 黄色
|
||||
'ERROR': '\033[31m', # 红色
|
||||
'CRITICAL': '\033[1;31m' # 红色加粗
|
||||
"DEBUG": "\033[34m", # 蓝色
|
||||
"INFO": "\033[32m", # 绿色
|
||||
"WARNING": "\033[33m", # 黄色
|
||||
"ERROR": "\033[31m", # 红色
|
||||
"CRITICAL": "\033[1;31m", # 红色加粗
|
||||
}
|
||||
|
||||
# Windows系统启用ANSI支持
|
||||
if platform.system() == 'Windows':
|
||||
if platform.system() == "Windows":
|
||||
import ctypes
|
||||
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
@@ -27,15 +27,17 @@ class ColoredFormatter(logging.Formatter):
|
||||
|
||||
def format(self, record):
|
||||
# 获取对应级别的颜色代码
|
||||
color = COLORS.get(record.levelname, '')
|
||||
color = COLORS.get(record.levelname, "")
|
||||
# 添加颜色代码和重置代码
|
||||
record.levelname = f"{color}{record.levelname}\033[0m"
|
||||
# 创建包含文件名和行号的固定宽度字符串
|
||||
record.fileloc = f"[{record.filename}:{record.lineno}]"
|
||||
return super().format(record)
|
||||
|
||||
|
||||
# 日志格式
|
||||
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
|
||||
FORMATTER = ColoredFormatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
|
||||
"%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
|
||||
)
|
||||
|
||||
# 日志级别映射
|
||||
@@ -55,9 +57,7 @@ class Logger:
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
@staticmethod
|
||||
def setup_logger(
|
||||
name: str
|
||||
) -> logging.Logger:
|
||||
def setup_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
设置并获取logger
|
||||
:param name: logger名称
|
||||
@@ -65,6 +65,7 @@ class Logger:
|
||||
"""
|
||||
# 导入 settings 对象
|
||||
from app.config.config import settings
|
||||
|
||||
# 从全局配置获取日志级别
|
||||
log_level_str = settings.LOG_LEVEL.lower()
|
||||
level = LOG_LEVELS.get(log_level_str, logging.INFO)
|
||||
@@ -97,7 +98,6 @@ class Logger:
|
||||
"""
|
||||
return Logger._loggers.get(name)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_log_levels(log_level: str):
|
||||
"""
|
||||
@@ -113,8 +113,6 @@ class Logger:
|
||||
# 可选:记录级别变更日志,但注意避免在日志模块内部产生过多日志
|
||||
# print(f"Updated log level for logger '{logger_name}' to {log_level_str.upper()}")
|
||||
updated_count += 1
|
||||
# if updated_count > 0:
|
||||
# print(f"Updated log level for {updated_count} loggers to {log_level_str.upper()}.")
|
||||
|
||||
|
||||
# 预定义的loggers
|
||||
@@ -207,4 +205,33 @@ def get_update_logger():
|
||||
|
||||
|
||||
def get_scheduler_routes():
|
||||
return Logger.setup_logger("scheduler_routes")
|
||||
return Logger.setup_logger("scheduler_routes")
|
||||
|
||||
|
||||
def get_message_converter_logger():
|
||||
return Logger.setup_logger("message_converter")
|
||||
|
||||
|
||||
def get_api_client_logger():
|
||||
return Logger.setup_logger("api_client")
|
||||
|
||||
|
||||
def get_openai_compatible_logger():
|
||||
return Logger.setup_logger("openai_compatible")
|
||||
|
||||
|
||||
def get_error_log_logger():
|
||||
return Logger.setup_logger("error_log")
|
||||
|
||||
|
||||
def get_request_log_logger():
|
||||
return Logger.setup_logger("request_log")
|
||||
|
||||
|
||||
def get_files_logger():
|
||||
return Logger.setup_logger("files")
|
||||
|
||||
|
||||
def get_vertex_express_logger():
|
||||
return Logger.setup_logger("vertex_express")
|
||||
|
||||
|
||||
13
app/main.py
13
app/main.py
@@ -1,18 +1,15 @@
|
||||
"""
|
||||
应用程序入口模块
|
||||
"""
|
||||
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 在导入应用程序配置之前加载 .env 文件到环境变量
|
||||
load_dotenv()
|
||||
|
||||
from app.core.application import create_app
|
||||
from app.log.logger import get_main_logger
|
||||
|
||||
# 创建应用程序实例
|
||||
app = create_app()
|
||||
|
||||
# 配置日志
|
||||
logger = get_main_logger()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = get_main_logger()
|
||||
logger.info("Starting application server...")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi.responses import RedirectResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
||||
from app.middleware.smart_routing_middleware import SmartRoutingMiddleware
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_middleware_logger
|
||||
@@ -30,6 +31,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
and not request.url.path.startswith(f"/{API_VERSION}")
|
||||
and not request.url.path.startswith("/health")
|
||||
and not request.url.path.startswith("/hf")
|
||||
and not request.url.path.startswith("/openai")
|
||||
and not request.url.path.startswith("/api/version/check")
|
||||
and not request.url.path.startswith("/vertex-express")
|
||||
and not request.url.path.startswith("/upload")
|
||||
):
|
||||
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
@@ -49,6 +54,9 @@ def setup_middlewares(app: FastAPI) -> None:
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
# 添加智能路由中间件(必须在认证中间件之前)
|
||||
app.add_middleware(SmartRoutingMiddleware)
|
||||
|
||||
# 添加认证中间件
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
@@ -58,7 +66,7 @@ def setup_middlewares(app: FastAPI) -> None:
|
||||
# 配置CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境建议配置具体的域名
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=[
|
||||
"GET",
|
||||
@@ -66,8 +74,8 @@ def setup_middlewares(app: FastAPI) -> None:
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
], # 明确指定允许的HTTP方法
|
||||
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
||||
expose_headers=["*"], # 允许前端访问的响应头
|
||||
max_age=600, # 预检请求缓存时间(秒)
|
||||
],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
max_age=600,
|
||||
)
|
||||
|
||||
210
app/middleware/smart_routing_middleware.py
Normal file
210
app/middleware/smart_routing_middleware.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_main_logger
|
||||
import re
|
||||
|
||||
logger = get_main_logger()
|
||||
|
||||
class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
# 简化的路由规则 - 直接根据检测结果路由
|
||||
pass
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not settings.URL_NORMALIZATION_ENABLED:
|
||||
return await call_next(request)
|
||||
logger.debug(f"request: {request}")
|
||||
original_path = str(request.url.path)
|
||||
method = request.method
|
||||
|
||||
# 尝试修复URL
|
||||
fixed_path, fix_info = self.fix_request_url(original_path, method, request)
|
||||
|
||||
if fixed_path != original_path:
|
||||
logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
|
||||
if fix_info:
|
||||
logger.debug(f"Fix details: {fix_info}")
|
||||
|
||||
# 重写请求路径
|
||||
request.scope["path"] = fixed_path
|
||||
request.scope["raw_path"] = fixed_path.encode()
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
|
||||
"""简化的URL修复逻辑"""
|
||||
|
||||
# 首先检查是否已经是正确的格式,如果是则不处理
|
||||
if self.is_already_correct_format(path):
|
||||
return path, None
|
||||
|
||||
# 1. 最高优先级:包含generateContent → Gemini格式
|
||||
if "generatecontent" in path.lower() or "v1beta/models" in path.lower():
|
||||
return self.fix_gemini_by_operation(path, method, request)
|
||||
|
||||
# 2. 第二优先级:包含/openai/ → OpenAI格式
|
||||
if "/openai/" in path.lower():
|
||||
return self.fix_openai_by_operation(path, method)
|
||||
|
||||
# 3. 第三优先级:包含/v1/ → v1格式
|
||||
if "/v1/" in path.lower():
|
||||
return self.fix_v1_by_operation(path, method)
|
||||
|
||||
# 4. 第四优先级:包含/chat/completions → chat功能
|
||||
if "/chat/completions" in path.lower():
|
||||
return "/v1/chat/completions", {"type": "v1_chat"}
|
||||
|
||||
# 5. 默认:原样传递
|
||||
return path, None
|
||||
|
||||
def is_already_correct_format(self, path: str) -> bool:
|
||||
"""检查是否已经是正确的API格式"""
|
||||
# 检查是否已经是正确的端点格式
|
||||
correct_patterns = [
|
||||
r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini原生
|
||||
r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini带前缀
|
||||
r"^/v1beta/models$", # Gemini模型列表
|
||||
r"^/gemini/v1beta/models$", # Gemini带前缀的模型列表
|
||||
r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # v1格式
|
||||
r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # OpenAI格式
|
||||
r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # HF格式
|
||||
r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Vertex Express Gemini格式
|
||||
r"^/vertex-express/v1beta/models$", # Vertex Express模型列表
|
||||
r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$", # Vertex Express OpenAI格式
|
||||
]
|
||||
|
||||
for pattern in correct_patterns:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def fix_gemini_by_operation(
|
||||
self, path: str, method: str, request: Request
|
||||
) -> tuple:
|
||||
"""根据Gemini操作修复,考虑端点偏好"""
|
||||
if method == "GET":
|
||||
return "/v1beta/models", {
|
||||
"role": "gemini_models",
|
||||
}
|
||||
|
||||
# 提取模型名称
|
||||
try:
|
||||
model_name = self.extract_model_name(path, request)
|
||||
except ValueError:
|
||||
# 无法提取模型名称,返回原路径不做处理
|
||||
return path, None
|
||||
|
||||
# 检测是否为流式请求
|
||||
is_stream = self.detect_stream_request(path, request)
|
||||
|
||||
# 检查是否有vertex-express偏好
|
||||
if "/vertex-express/" in path.lower():
|
||||
if is_stream:
|
||||
target_url = (
|
||||
f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent"
|
||||
)
|
||||
else:
|
||||
target_url = (
|
||||
f"/vertex-express/v1beta/models/{model_name}:generateContent"
|
||||
)
|
||||
|
||||
fix_info = {
|
||||
"rule": (
|
||||
"vertex_express_generate"
|
||||
if not is_stream
|
||||
else "vertex_express_stream"
|
||||
),
|
||||
"preference": "vertex_express_format",
|
||||
"is_stream": is_stream,
|
||||
"model": model_name,
|
||||
}
|
||||
else:
|
||||
# 标准Gemini端点
|
||||
if is_stream:
|
||||
target_url = f"/v1beta/models/{model_name}:streamGenerateContent"
|
||||
else:
|
||||
target_url = f"/v1beta/models/{model_name}:generateContent"
|
||||
|
||||
fix_info = {
|
||||
"rule": "gemini_generate" if not is_stream else "gemini_stream",
|
||||
"preference": "gemini_format",
|
||||
"is_stream": is_stream,
|
||||
"model": model_name,
|
||||
}
|
||||
|
||||
return target_url, fix_info
|
||||
|
||||
def fix_openai_by_operation(self, path: str, method: str) -> tuple:
|
||||
"""根据操作类型修复OpenAI格式"""
|
||||
if method == "POST":
|
||||
if "chat" in path.lower() or "completion" in path.lower():
|
||||
return "/openai/v1/chat/completions", {"type": "openai_chat"}
|
||||
elif "embedding" in path.lower():
|
||||
return "/openai/v1/embeddings", {"type": "openai_embeddings"}
|
||||
elif "image" in path.lower():
|
||||
return "/openai/v1/images/generations", {"type": "openai_images"}
|
||||
elif "audio" in path.lower():
|
||||
return "/openai/v1/audio/speech", {"type": "openai_audio"}
|
||||
elif method == "GET":
|
||||
if "model" in path.lower():
|
||||
return "/openai/v1/models", {"type": "openai_models"}
|
||||
|
||||
return path, None
|
||||
|
||||
def fix_v1_by_operation(self, path: str, method: str) -> tuple:
|
||||
"""根据操作类型修复v1格式"""
|
||||
if method == "POST":
|
||||
if "chat" in path.lower() or "completion" in path.lower():
|
||||
return "/v1/chat/completions", {"type": "v1_chat"}
|
||||
elif "embedding" in path.lower():
|
||||
return "/v1/embeddings", {"type": "v1_embeddings"}
|
||||
elif "image" in path.lower():
|
||||
return "/v1/images/generations", {"type": "v1_images"}
|
||||
elif "audio" in path.lower():
|
||||
return "/v1/audio/speech", {"type": "v1_audio"}
|
||||
elif method == "GET":
|
||||
if "model" in path.lower():
|
||||
return "/v1/models", {"type": "v1_models"}
|
||||
|
||||
return path, None
|
||||
|
||||
def detect_stream_request(self, path: str, request: Request) -> bool:
|
||||
"""检测是否为流式请求"""
|
||||
# 1. 路径中包含stream关键词
|
||||
if "stream" in path.lower():
|
||||
return True
|
||||
|
||||
# 2. 查询参数
|
||||
if request.query_params.get("stream") == "true":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def extract_model_name(self, path: str, request: Request) -> str:
|
||||
"""从请求中提取模型名称,用于构建Gemini API URL"""
|
||||
# 1. 从请求体中提取
|
||||
try:
|
||||
if hasattr(request, "_body") and request._body:
|
||||
import json
|
||||
|
||||
body = json.loads(request._body.decode())
|
||||
if "model" in body and body["model"]:
|
||||
return body["model"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. 从查询参数中提取
|
||||
model_param = request.query_params.get("model")
|
||||
if model_param:
|
||||
return model_param
|
||||
|
||||
# 3. 从路径中提取(用于已包含模型名称的路径)
|
||||
match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# 4. 如果无法提取模型名称,抛出异常
|
||||
raise ValueError("Unable to extract model name from request")
|
||||
@@ -1,15 +1,17 @@
|
||||
"""
|
||||
配置路由模块
|
||||
"""
|
||||
from typing import Any, Dict
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_config_routes_logger, Logger # 导入 Logger 类
|
||||
from app.log.logger import Logger, get_config_routes_logger
|
||||
from app.service.config.config_service import ConfigService
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
logger = get_config_routes_logger()
|
||||
@@ -34,10 +36,10 @@ async def update_config(config_data: Dict[str, Any], request: Request):
|
||||
result = await ConfigService.update_config(config_data)
|
||||
# 配置更新成功后,立即更新所有 logger 的级别
|
||||
Logger.update_log_levels(config_data["LOG_LEVEL"])
|
||||
logger.info("Log levels updated after configuration change.") # 添加日志记录
|
||||
logger.info("Log levels updated after configuration change.")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating config or log levels: {e}", exc_info=True) # 记录详细错误
|
||||
logger.error(f"Error updating config or log levels: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@@ -51,3 +53,81 @@ async def reset_config(request: Request):
|
||||
return await ConfigService.reset_config()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
class DeleteKeysRequest(BaseModel):
|
||||
keys: List[str] = Field(..., description="List of API keys to delete")
|
||||
|
||||
|
||||
@router.delete("/keys/{key_to_delete}", response_model=Dict[str, Any])
|
||||
async def delete_single_key(key_to_delete: str, request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized attempt to delete key: {key_to_delete}")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
try:
|
||||
logger.info(f"Attempting to delete key: {key_to_delete}")
|
||||
result = await ConfigService.delete_key(key_to_delete)
|
||||
if not result.get("success"):
|
||||
raise HTTPException(
|
||||
status_code=(
|
||||
404 if "not found" in result.get("message", "").lower() else 400
|
||||
),
|
||||
detail=result.get("message"),
|
||||
)
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting key '{key_to_delete}': {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting key: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/keys/delete-selected", response_model=Dict[str, Any])
|
||||
async def delete_selected_keys_route(
|
||||
delete_request: DeleteKeysRequest, request: Request
|
||||
):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized attempt to bulk delete keys")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
if not delete_request.keys:
|
||||
logger.warning("Attempt to bulk delete keys with an empty list.")
|
||||
raise HTTPException(status_code=400, detail="No keys provided for deletion.")
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to bulk delete {len(delete_request.keys)} keys.")
|
||||
result = await ConfigService.delete_selected_keys(delete_request.keys)
|
||||
if not result.get("success") and result.get("deleted_count", 0) == 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=result.get("message", "Failed to delete keys.")
|
||||
)
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error bulk deleting keys: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error bulk deleting keys: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/ui/models")
|
||||
async def get_ui_models(request: Request):
|
||||
auth_token_cookie = request.cookies.get("auth_token")
|
||||
if not auth_token_cookie or not verify_auth_token(auth_token_cookie):
|
||||
logger.warning("Unauthorized access attempt to /api/config/ui/models")
|
||||
raise HTTPException(status_code=403, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
models = await ConfigService.fetch_ui_models()
|
||||
return models
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in /ui/models endpoint: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching UI models: {str(e)}",
|
||||
)
|
||||
|
||||
233
app/router/error_log_routes.py
Normal file
233
app/router/error_log_routes.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
日志路由模块
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
Response,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_log_routes_logger
|
||||
from app.service.error_log import error_log_service
|
||||
|
||||
router = APIRouter(prefix="/api/logs", tags=["logs"])
|
||||
|
||||
logger = get_log_routes_logger()
|
||||
|
||||
|
||||
class ErrorLogListItem(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_code: Optional[int] = None
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
|
||||
class ErrorLogListResponse(BaseModel):
|
||||
logs: List[ErrorLogListItem]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/errors", response_model=ErrorLogListResponse)
|
||||
async def get_error_logs_api(
|
||||
request: Request,
|
||||
limit: int = Query(10, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
key_search: Optional[str] = Query(
|
||||
None, description="Search term for Gemini key (partial match)"
|
||||
),
|
||||
error_search: Optional[str] = Query(
|
||||
None, description="Search term for error type or log message"
|
||||
),
|
||||
error_code_search: Optional[str] = Query(
|
||||
None, description="Search term for error code"
|
||||
),
|
||||
start_date: Optional[datetime] = Query(
|
||||
None, description="Start datetime for filtering"
|
||||
),
|
||||
end_date: Optional[datetime] = Query(
|
||||
None, description="End datetime for filtering"
|
||||
),
|
||||
sort_by: str = Query(
|
||||
"id", description="Field to sort by (e.g., 'id', 'request_time')"
|
||||
),
|
||||
sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"),
|
||||
):
|
||||
"""
|
||||
获取错误日志列表 (返回错误码),支持过滤和排序
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
limit: 限制数量
|
||||
offset: 偏移量
|
||||
key_search: 密钥搜索
|
||||
error_search: 错误搜索 (可能搜索类型或日志内容,由DB层决定)
|
||||
error_code_search: 错误码搜索
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
sort_by: 排序字段
|
||||
sort_order: 排序顺序
|
||||
|
||||
Returns:
|
||||
ErrorLogListResponse: An object containing the list of logs (with error_code) and the total count.
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to error logs list")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
result = await error_log_service.process_get_error_logs(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
logs_data = result["logs"]
|
||||
total_count = result["total"]
|
||||
|
||||
validated_logs = [ErrorLogListItem(**log) for log in logs_data]
|
||||
return ErrorLogListResponse(logs=validated_logs, total=total_count)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error logs list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get error logs list: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class ErrorLogDetailResponse(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_log: Optional[str] = None
|
||||
request_msg: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
|
||||
@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse)
|
||||
async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=1)):
|
||||
"""
|
||||
根据日志 ID 获取错误日志的详细信息 (包括 error_log 和 request_msg)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(
|
||||
f"Unauthorized access attempt to error log details for ID: {log_id}"
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
log_details = await error_log_service.process_get_error_log_details(
|
||||
log_id=log_id
|
||||
)
|
||||
if not log_details:
|
||||
raise HTTPException(status_code=404, detail="Error log not found")
|
||||
|
||||
return ErrorLogDetailResponse(**log_details)
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get error log details: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_logs_bulk_api(
|
||||
request: Request, payload: Dict[str, List[int]] = Body(...)
|
||||
):
|
||||
"""
|
||||
批量删除错误日志 (异步)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to bulk delete error logs")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
log_ids = payload.get("ids")
|
||||
if not log_ids:
|
||||
raise HTTPException(status_code=400, detail="No log IDs provided for deletion.")
|
||||
|
||||
try:
|
||||
deleted_count = await error_log_service.process_delete_error_logs_by_ids(
|
||||
log_ids
|
||||
)
|
||||
# 注意:异步函数返回的是尝试删除的数量,可能不是精确值
|
||||
logger.info(
|
||||
f"Attempted bulk deletion for {deleted_count} error logs with IDs: {log_ids}"
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error bulk deleting error logs with IDs {log_ids}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error during bulk deletion"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/errors/all", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_all_error_logs_api(request: Request):
|
||||
"""
|
||||
删除所有错误日志 (异步)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to delete all error logs")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
deleted_count = await error_log_service.process_delete_all_error_logs()
|
||||
logger.info(f"Successfully deleted all {deleted_count} error logs.")
|
||||
# No body needed for 204 response
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error deleting all error logs: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error during deletion of all logs"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)):
|
||||
"""
|
||||
删除单个错误日志 (异步)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized access attempt to delete error log ID: {log_id}")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
success = await error_log_service.process_delete_error_log_by_id(log_id)
|
||||
if not success:
|
||||
# 服务层现在在未找到时返回 False,我们在这里转换为 404
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Error log with ID {log_id} not found"
|
||||
)
|
||||
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(f"Error deleting error log with ID {log_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error during deletion"
|
||||
)
|
||||
295
app/router/files_routes.py
Normal file
295
app/router/files_routes.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Files API 路由
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request, Query, Depends, Header, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.file_models import (
|
||||
FileMetadata,
|
||||
ListFilesResponse,
|
||||
DeleteFileResponse
|
||||
)
|
||||
from app.log.logger import get_files_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.service.files.files_service import get_files_service
|
||||
from app.service.files.file_upload_handler import get_upload_handler
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
router = APIRouter()
|
||||
security_service = SecurityService()
|
||||
|
||||
|
||||
@router.post("/upload/v1beta/files")
|
||||
async def upload_file_init(
|
||||
request: Request,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
||||
x_goog_upload_protocol: Optional[str] = Header(None),
|
||||
x_goog_upload_command: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
||||
):
|
||||
"""初始化文件上传"""
|
||||
logger.debug(f"Upload file request: {request.method=}, {request.url=}, {auth_token=}, {x_goog_upload_protocol=}, {x_goog_upload_command=}, {x_goog_upload_header_content_length=}, {x_goog_upload_header_content_type=}")
|
||||
|
||||
# 檢查是否是實際的上傳請求(有 upload_id)
|
||||
if request.query_params.get("upload_id") and x_goog_upload_command in ["upload", "upload, finalize"]:
|
||||
logger.debug("This is an upload request, not initialization. Redirecting to handle_upload.")
|
||||
return await handle_upload(
|
||||
upload_path="v1beta/files",
|
||||
request=request,
|
||||
key=request.query_params.get("key"),
|
||||
auth_token=auth_token
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 获取请求体
|
||||
body = await request.body()
|
||||
|
||||
# 构建请求主机 URL
|
||||
request_host = f"{request.url.scheme}://{request.url.netloc}"
|
||||
logger.info(f"Request host: {request_host}")
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
"x-goog-upload-protocol": x_goog_upload_protocol or "resumable",
|
||||
"x-goog-upload-command": x_goog_upload_command or "start",
|
||||
}
|
||||
|
||||
if x_goog_upload_header_content_length:
|
||||
headers["x-goog-upload-header-content-length"] = x_goog_upload_header_content_length
|
||||
if x_goog_upload_header_content_type:
|
||||
headers["x-goog-upload-header-content-type"] = x_goog_upload_header_content_type
|
||||
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
response_data, response_headers = await files_service.initialize_upload(
|
||||
headers=headers,
|
||||
body=body,
|
||||
user_token=user_token,
|
||||
request_host=request_host # 傳遞請求主機
|
||||
)
|
||||
|
||||
logger.info(f"Upload initialization response: {response_data}")
|
||||
logger.info(f"Upload initialization response headers: {response_headers}")
|
||||
|
||||
logger.info(f"Upload initialization response headers: {response_data}")
|
||||
# 返回响应
|
||||
return JSONResponse(
|
||||
content=response_data,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Upload initialization failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in upload initialization: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1beta/files")
|
||||
async def list_files(
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小", alias="pageSize"),
|
||||
page_token: Optional[str] = Query(None, description="分页标记", alias="pageToken"),
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> ListFilesResponse:
|
||||
"""列出文件"""
|
||||
logger.debug(f"List files: {page_size=}, {page_token=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token(如果启用用户隔离)
|
||||
user_token = auth_token if settings.FILES_USER_ISOLATION_ENABLED else None
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
return await files_service.list_files(
|
||||
page_size=page_size,
|
||||
page_token=page_token,
|
||||
user_token=user_token
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"List files failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in list files: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1beta/files/{file_id:path}")
|
||||
async def get_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> FileMetadata:
|
||||
"""获取文件信息"""
|
||||
logger.debug(f"Get file request: {file_id=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
return await files_service.get_file(f"files/{file_id}", user_token)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Get file failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in get file: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/v1beta/files/{file_id:path}")
|
||||
async def delete_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> DeleteFileResponse:
|
||||
"""删除文件"""
|
||||
logger.info(f"Delete file: {file_id=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
success = await files_service.delete_file(f"files/{file_id}", user_token)
|
||||
|
||||
return DeleteFileResponse(
|
||||
success=success,
|
||||
message="File deleted successfully" if success else "Failed to delete file"
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Delete file failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in delete file: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# 处理上传请求的通配符路由
|
||||
@router.api_route("/upload/{upload_path:path}", methods=["GET", "POST", "PUT"])
|
||||
async def handle_upload(
|
||||
upload_path: str,
|
||||
request: Request,
|
||||
key: Optional[str] = Query(None), # 從查詢參數獲取 key
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
):
|
||||
"""处理文件上传请求"""
|
||||
try:
|
||||
logger.info(f"Handling upload request: {request.method} {upload_path}, key={key}")
|
||||
|
||||
# 從查詢參數獲取 upload_id
|
||||
upload_id = request.query_params.get("upload_id")
|
||||
if not upload_id:
|
||||
raise HTTPException(status_code=400, detail="Missing upload_id")
|
||||
|
||||
# 從 session 獲取真實的 API key
|
||||
files_service = await get_files_service()
|
||||
session_info = await files_service.get_upload_session(upload_id)
|
||||
if not session_info:
|
||||
logger.error(f"No session found for upload_id: {upload_id}")
|
||||
raise HTTPException(status_code=404, detail="Upload session not found")
|
||||
|
||||
real_api_key = session_info["api_key"]
|
||||
original_upload_url = session_info["upload_url"]
|
||||
|
||||
# 使用真實的 API key 構建完整的 Google 上傳 URL
|
||||
# 保留原始 URL 的所有參數,但使用真實的 API key
|
||||
upload_url = original_upload_url
|
||||
logger.info(f"Using real API key for upload: {real_api_key[:8]}...{real_api_key[-4:]}")
|
||||
|
||||
# 代理上传请求
|
||||
upload_handler = get_upload_handler()
|
||||
return await upload_handler.proxy_upload_request(
|
||||
request=request,
|
||||
upload_url=upload_url,
|
||||
files_service=files_service
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Upload handling failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in upload handling: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# 为兼容性添加 /gemini 前缀的路由
|
||||
@router.post("/gemini/upload/v1beta/files")
|
||||
async def gemini_upload_file_init(
|
||||
request: Request,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
||||
x_goog_upload_protocol: Optional[str] = Header(None),
|
||||
x_goog_upload_command: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
||||
):
|
||||
"""初始化文件上传(Gemini 前缀)"""
|
||||
return await upload_file_init(
|
||||
request,
|
||||
auth_token,
|
||||
x_goog_upload_protocol,
|
||||
x_goog_upload_command,
|
||||
x_goog_upload_header_content_length,
|
||||
x_goog_upload_header_content_type
|
||||
)
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files")
|
||||
async def gemini_list_files(
|
||||
page_size: int = Query(10, ge=1, le=100, alias="pageSize"),
|
||||
page_token: Optional[str] = Query(None, alias="pageToken"),
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> ListFilesResponse:
|
||||
"""列出文件(Gemini 前缀)"""
|
||||
return await list_files(page_size, page_token, auth_token)
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files/{file_id:path}")
|
||||
async def gemini_get_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> FileMetadata:
|
||||
"""获取文件信息(Gemini 前缀)"""
|
||||
return await get_file(file_id, auth_token)
|
||||
|
||||
|
||||
@router.delete("/gemini/v1beta/files/{file_id:path}")
|
||||
async def gemini_delete_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> DeleteFileResponse:
|
||||
"""删除文件(Gemini 前缀)"""
|
||||
return await delete_file(file_id, auth_token)
|
||||
@@ -1,23 +1,23 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
import asyncio # 导入 asyncio
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest # 添加导入
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_routes import get_tts_chat_service
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
|
||||
# 路由设置
|
||||
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
||||
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
||||
logger = get_gemini_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
|
||||
@@ -43,67 +43,60 @@ async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""获取可用的Gemini模型列表"""
|
||||
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
logger.info("-" * 50 + operation_name + "-" * 50)
|
||||
logger.info("Handling Gemini models list request")
|
||||
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
models_json = model_service.get_gemini_models(api_key)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
||||
|
||||
# 添加搜索模型
|
||||
if settings.SEARCH_MODELS:
|
||||
for name in settings.SEARCH_MODELS:
|
||||
model = model_mapping.get(name)
|
||||
|
||||
try:
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-search"
|
||||
display_name = f'{item.get("displayName")} For Search'
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
# 添加图像生成模型
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
model = model_mapping.get(name)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-image"
|
||||
display_name = f'{item.get("displayName")} For Image'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
# 添加思考模型的非思考版本
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
model = model_mapping.get(name)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-non-thinking"
|
||||
display_name = f'{item.get("displayName")} Non Thinking'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
return models_json
|
||||
|
||||
if settings.SEARCH_MODELS:
|
||||
for name in settings.SEARCH_MODELS:
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
|
||||
logger.info("Gemini models list request successful")
|
||||
return models_json
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
|
||||
@router_v1beta.post("/models/{model_name}:generateContent")
|
||||
@RetryHandler(max_retries=settings.MAX_RETRIES, key_arg="api_key")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
@@ -112,30 +105,57 @@ async def generate_content(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""非流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
# 检测是否为原生Gemini TTS请求
|
||||
is_native_tts = False
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
is_native_tts = True
|
||||
logger.info("Detected native Gemini TTS request")
|
||||
logger.info(f"TTS responseModalities: {response_modalities}")
|
||||
logger.info(f"TTS speechConfig: {speech_config}")
|
||||
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
# 所有原生TTS请求都使用TTS增强服务
|
||||
if is_native_tts:
|
||||
try:
|
||||
logger.info("Using native TTS enhanced service")
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
response = await tts_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"Native TTS processing failed, falling back to standard service: {e}")
|
||||
|
||||
# 使用标准服务处理所有其他请求(非TTS)
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
||||
@RetryHandler(max_retries=settings.MAX_RETRIES, key_arg="api_key")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
@@ -144,25 +164,52 @@ async def stream_generate_content(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Streaming request failed") from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:countTokens")
|
||||
@router_v1beta.post("/models/{model_name}:countTokens")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def count_tokens(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""处理 Gemini token 计数请求。"""
|
||||
operation_name = "gemini_count_tokens"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Token counting failed"):
|
||||
logger.info(f"Handling Gemini token count request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response = await chat_service.count_tokens(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
@@ -211,7 +258,7 @@ async def reset_selected_key_fail_counts(
|
||||
"""批量重置选定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
|
||||
keys_to_reset = request.keys
|
||||
key_type = request.key_type # 获取类型用于日志记录和响应消息
|
||||
key_type = request.key_type
|
||||
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
|
||||
|
||||
if not keys_to_reset:
|
||||
@@ -227,38 +274,31 @@ async def reset_selected_key_fail_counts(
|
||||
if result:
|
||||
reset_count += 1
|
||||
else:
|
||||
# 记录未找到的密钥,但不视为致命错误
|
||||
logger.warning(f"Key not found during selective reset: {key}")
|
||||
except Exception as key_error:
|
||||
# 记录单个密钥重置时的错误
|
||||
logger.error(f"Error resetting key {key}: {str(key_error)}")
|
||||
errors.append(f"Key {key}: {str(key_error)}")
|
||||
|
||||
if errors:
|
||||
# 如果有错误,报告部分成功或完全失败
|
||||
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
||||
# 确定最终状态码和成功标志
|
||||
final_success = reset_count > 0
|
||||
status_code = 207 if final_success and errors else 500 # 207 Multi-Status if partially successful, 500 if completely failed
|
||||
status_code = 207 if final_success and errors else 500
|
||||
return JSONResponse({
|
||||
"success": final_success,
|
||||
"message": error_message,
|
||||
"reset_count": reset_count
|
||||
}, status_code=status_code)
|
||||
|
||||
# 完全成功的情况
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
||||
"reset_count": reset_count
|
||||
})
|
||||
except Exception as e:
|
||||
# 捕获循环外的意外错误
|
||||
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
|
||||
|
||||
|
||||
|
||||
@router.post("/reset-fail-count/{api_key}")
|
||||
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""重置指定Gemini API密钥的失败计数"""
|
||||
@@ -274,6 +314,7 @@ async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(g
|
||||
logger.error(f"Failed to reset key failure count: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
|
||||
|
||||
|
||||
@router.post("/verify-key/{api_key}")
|
||||
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""验证Gemini API密钥的有效性"""
|
||||
@@ -281,14 +322,14 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
logger.info("Verifying API key validity")
|
||||
|
||||
try:
|
||||
# 使用generate_content接口测试key的有效性
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=[{"text": "hi"}]
|
||||
parts=[{"text": "hi"}],
|
||||
)
|
||||
]
|
||||
],
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
|
||||
)
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
@@ -298,11 +339,12 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
)
|
||||
|
||||
if response:
|
||||
return JSONResponse({"status": "valid"})
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return JSONResponse({"status": "valid"})
|
||||
except Exception as e:
|
||||
logger.error(f"Key verification failed: {str(e)}")
|
||||
|
||||
# 验证出现异常时增加失败计数
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
@@ -325,76 +367,72 @@ async def verify_selected_keys(
|
||||
if not keys_to_verify:
|
||||
return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)
|
||||
|
||||
valid_count = 0
|
||||
invalid_count = 0
|
||||
verification_errors = {} # 存储验证过程中的错误
|
||||
successful_keys = []
|
||||
failed_keys = {}
|
||||
|
||||
async def _verify_single_key(api_key: str):
|
||||
"""内部函数,用于验证单个密钥并处理异常"""
|
||||
nonlocal valid_count, invalid_count # 允许修改外部计数器
|
||||
nonlocal successful_keys, failed_keys
|
||||
try:
|
||||
# 重用单密钥验证逻辑的核心部分
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])]
|
||||
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
|
||||
)
|
||||
# 注意:这里直接调用 chat_service.generate_content,不依赖于 key_manager 获取密钥
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
)
|
||||
# 如果上面没有抛出异常,则认为密钥有效
|
||||
valid_count += 1
|
||||
successful_keys.append(api_key)
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return api_key, "valid", None
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.warning(f"Key verification failed for {api_key}: {error_message}")
|
||||
# 验证失败时增加失败计数 (使用与 /verify-key 一致的逻辑)
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
|
||||
else:
|
||||
# 如果密钥不在计数中(可能刚添加或从未失败),初始化为1
|
||||
key_manager.key_failure_counts[api_key] = 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
|
||||
invalid_count += 1
|
||||
failed_keys[api_key] = error_message
|
||||
return api_key, "invalid", error_message
|
||||
|
||||
# 并发执行所有密钥的验证
|
||||
tasks = [_verify_single_key(key) for key in keys_to_verify]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True) # return_exceptions=True 捕获任务本身的异常
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理并发执行的结果
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
# 捕获 asyncio.gather 可能遇到的异常(例如任务被取消)
|
||||
logger.error(f"An unexpected error occurred during bulk verification task: {result}")
|
||||
# 可以选择如何处理这种任务级别的错误,这里我们简单记录
|
||||
# 也可以将其计入 invalid_count 或单独记录
|
||||
elif result:
|
||||
key, status, error = result
|
||||
if status == "invalid" and error:
|
||||
verification_errors[key] = error # 记录具体的验证错误信息
|
||||
if not isinstance(result, Exception) and result:
|
||||
key, status, error = result
|
||||
elif isinstance(result, Exception):
|
||||
logger.error(f"Task execution error during bulk verification: {result}")
|
||||
|
||||
valid_count = len(successful_keys)
|
||||
invalid_count = len(failed_keys)
|
||||
logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")
|
||||
|
||||
# 根据是否有错误决定最终消息和状态
|
||||
if verification_errors or valid_count + invalid_count != len(keys_to_verify): # 检查是否有错误或任务异常
|
||||
error_summary = "; ".join([f"{k}: {v}" for k, v in verification_errors.items()])
|
||||
message = f"批量验证完成,但出现问题。有效: {valid_count}, 无效: {invalid_count}。错误详情: {error_summary or '任务执行异常'}"
|
||||
return JSONResponse({
|
||||
"success": False, # 标记为失败,因为有错误
|
||||
"message": message,
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count,
|
||||
"errors": verification_errors
|
||||
}, status_code=207) # 207 Multi-Status 表示部分成功/失败
|
||||
else:
|
||||
# 完全成功
|
||||
if failed_keys:
|
||||
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"批量验证成功完成。有效: {valid_count}, 无效: {invalid_count}",
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": failed_keys,
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count
|
||||
})
|
||||
else:
|
||||
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": {},
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": 0
|
||||
})
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
日志路由模块
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, HTTPException, Request, Query, Path
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_log_routes_logger
|
||||
# 假设这些服务函数已更新或添加
|
||||
from app.database.services import get_error_logs, get_error_logs_count, get_error_log_details
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter(prefix="/api/logs", tags=["logs"])
|
||||
|
||||
logger = get_log_routes_logger()
|
||||
|
||||
|
||||
# Define a response model that includes the total count for pagination
|
||||
# 用于列表响应的模型,假设 get_error_logs 返回包含 error_code 的字典
|
||||
class ErrorLogListItem(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_code: Optional[int] = None # 列表显示错误码 (应为整数)
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
class ErrorLogListResponse(BaseModel):
|
||||
logs: List[ErrorLogListItem] # 使用定义的模型列表
|
||||
total: int
|
||||
|
||||
@router.get("/errors", response_model=ErrorLogListResponse)
|
||||
async def get_error_logs_api(
|
||||
request: Request,
|
||||
limit: int = Query(10, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
key_search: Optional[str] = Query(None, description="Search term for Gemini key (partial match)"),
|
||||
error_search: Optional[str] = Query(None, description="Search term for error type or log message"), # 数据库查询需处理
|
||||
start_date: Optional[datetime] = Query(None, description="Start datetime for filtering"),
|
||||
end_date: Optional[datetime] = Query(None, description="End datetime for filtering")
|
||||
):
|
||||
"""
|
||||
获取错误日志列表 (返回错误码)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
limit: 限制数量
|
||||
offset: 偏移量
|
||||
key_search: 密钥搜索
|
||||
error_search: 错误搜索 (可能搜索类型或日志内容,由DB层决定)
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
ErrorLogListResponse: An object containing the list of logs (with error_code) and the total count.
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to error logs list")
|
||||
# API 返回 401 更合适
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
# 假设 get_error_logs 现在返回包含 error_code 的字典列表
|
||||
# 并且可以接受 include_error_code 参数 (如果需要显式指定)
|
||||
logs_data = await get_error_logs(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
key_search=key_search,
|
||||
error_search=error_search, # 数据库查询需要处理这个
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
# include_error_code=True # 如果需要显式传递
|
||||
)
|
||||
# Fetch total count with the same search parameters
|
||||
total_count = await get_error_logs_count(
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 验证并转换数据以匹配 Pydantic 模型
|
||||
validated_logs = [ErrorLogListItem(**log) for log in logs_data]
|
||||
return ErrorLogListResponse(logs=validated_logs, total=total_count)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error logs list: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get error logs list: {str(e)}")
|
||||
|
||||
|
||||
# 新增:获取错误日志详情的路由
|
||||
class ErrorLogDetailResponse(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_log: Optional[str] = None # 详情接口返回完整的 error_log
|
||||
request_msg: Optional[str] = None # 详情接口返回 request_msg
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse)
|
||||
async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=1)):
|
||||
"""
|
||||
根据日志 ID 获取错误日志的详细信息 (包括 error_log 和 request_msg)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized access attempt to error log details for ID: {log_id}")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
# 假设存在一个函数 get_error_log_details(log_id) 来获取完整信息
|
||||
log_details = await get_error_log_details(log_id=log_id)
|
||||
if not log_details:
|
||||
raise HTTPException(status_code=404, detail="Error log not found")
|
||||
|
||||
# 假设 get_error_log_details 返回一个字典或兼容 Pydantic 的对象
|
||||
return ErrorLogDetailResponse(**log_details)
|
||||
except HTTPException as http_exc:
|
||||
# Re-raise HTTPException (like 404)
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get error log details: {str(e)}")
|
||||
113
app/router/openai_compatiable_routes.py
Normal file
113
app/router/openai_compatiable_routes.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.openai_models import (
|
||||
ChatRequest,
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
security_service = SecurityService()
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_next_working_key_wrapper(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
|
||||
async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取OpenAI聊天服务实例"""
|
||||
return OpenAICompatiableService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/openai/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""获取可用模型列表。"""
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await openai_service.get_models(api_key)
|
||||
|
||||
|
||||
@router.post("/openai/v1/chat/completions")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key()
|
||||
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
|
||||
if is_image_chat:
|
||||
response = await openai_service.create_image_chat_completion(request, current_api_key)
|
||||
return response
|
||||
else:
|
||||
response = await openai_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/openai/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
request.model = settings.CREATE_IMAGE_MODEL
|
||||
return await openai_service.generate_images(request)
|
||||
|
||||
|
||||
@router.post("/openai/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理文本嵌入请求。"""
|
||||
operation_name = "embedding"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await openai_service.create_embeddings(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
@@ -7,23 +7,26 @@ from app.domain.openai_models import (
|
||||
ChatRequest,
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.tts.tts_service import TTSService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
embedding_service = EmbeddingService()
|
||||
image_create_service = ImageCreateService()
|
||||
tts_service = TTSService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
@@ -41,62 +44,63 @@ async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_mana
|
||||
return OpenAIChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
async def get_tts_service():
|
||||
"""获取TTS服务实例"""
|
||||
return tts_service
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
@router.get("/hf/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
logger.info("-" * 50 + "list_models" + "-" * 50)
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
return model_service.get_gemini_openai_models(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching models list"
|
||||
) from e
|
||||
"""获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。"""
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await model_service.get_gemini_openai_models(api_key)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@router.post("/hf/v1/chat/completions")
|
||||
@RetryHandler(max_retries=settings.MAX_RETRIES, key_arg="api_key")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager), # 保留 key_manager 用于获取 paid_key
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
|
||||
):
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
api_key = await key_manager.get_paid_key()
|
||||
logger.info("-" * 50 + "chat_completion" + "-" * 50)
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
"""处理 OpenAI 聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key()
|
||||
|
||||
if not model_service.check_model_support(request.model):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
|
||||
try:
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
response = await chat_service.create_image_chat_completion(request, api_key)
|
||||
if not await model_service.check_model_support(request.model):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
|
||||
if is_image_chat:
|
||||
response = await chat_service.create_image_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
else:
|
||||
response = await chat_service.create_chat_completion(request, api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
logger.info("Chat completion request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
response = await chat_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@@ -105,18 +109,12 @@ async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
):
|
||||
logger.info("-" * 50 + "generate_image" + "-" * 50)
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
|
||||
try:
|
||||
"""处理 OpenAI 图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
response = image_create_service.generate_images(request)
|
||||
logger.info("Image generation request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation request failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Image generation request failed"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@@ -126,19 +124,16 @@ async def embedding(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
logger.info("-" * 50 + "embedding" + "-" * 50)
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
"""处理 OpenAI 文本嵌入请求。"""
|
||||
operation_name = "embedding"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
response = await embedding_service.create_embedding(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
logger.info("Embedding request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Embedding request failed") from e
|
||||
|
||||
|
||||
@router.get("/v1/keys/list")
|
||||
@@ -147,10 +142,10 @@ async def get_keys_list(
|
||||
_=Depends(security_service.verify_auth_token),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取有效和无效的API key列表"""
|
||||
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
|
||||
logger.info("Handling keys list request")
|
||||
try:
|
||||
"""获取有效和无效的API key列表 (需要管理 Token 认证)。"""
|
||||
operation_name = "get_keys_list"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling keys list request")
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -160,8 +155,21 @@ async def get_keys_list(
|
||||
},
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting keys list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching keys list"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/v1/audio/speech")
|
||||
@router.post("/hf/v1/audio/speech")
|
||||
async def text_to_speech(
|
||||
request: TTSRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
):
|
||||
"""处理 OpenAI TTS 请求。"""
|
||||
operation_name = "text_to_speech"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling TTS request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
audio_data = await tts_service.create_tts(request, api_key)
|
||||
return Response(content=audio_data, media_type="audio/wav")
|
||||
|
||||
@@ -8,13 +8,12 @@ from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_routes_logger
|
||||
from app.router import gemini_routes, openai_routes, config_routes, log_routes, scheduler_routes, stats_routes # 新增导入 stats_routes
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes, files_routes
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.stats_service import StatsService
|
||||
from app.service.stats.stats_service import StatsService
|
||||
|
||||
logger = get_routes_logger()
|
||||
|
||||
# 配置Jinja2模板
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
@@ -25,21 +24,22 @@ def setup_routers(app: FastAPI) -> None:
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
# 包含API路由
|
||||
app.include_router(openai_routes.router)
|
||||
app.include_router(gemini_routes.router)
|
||||
app.include_router(gemini_routes.router_v1beta)
|
||||
app.include_router(config_routes.router)
|
||||
app.include_router(log_routes.router)
|
||||
app.include_router(scheduler_routes.router) # 新增包含 scheduler 路由
|
||||
app.include_router(stats_routes.router) # 包含 stats API 路由
|
||||
app.include_router(error_log_routes.router)
|
||||
app.include_router(scheduler_routes.router)
|
||||
app.include_router(stats_routes.router)
|
||||
app.include_router(version_routes.router)
|
||||
app.include_router(openai_compatiable_routes.router)
|
||||
app.include_router(vertex_express_routes.router)
|
||||
app.include_router(files_routes.router)
|
||||
|
||||
# 添加页面路由
|
||||
setup_page_routes(app)
|
||||
|
||||
# 添加健康检查路由
|
||||
setup_health_routes(app)
|
||||
setup_api_stats_routes(app) # Add API stats routes
|
||||
setup_api_stats_routes(app)
|
||||
|
||||
|
||||
def setup_page_routes(app: FastAPI) -> None:
|
||||
@@ -104,16 +104,14 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
"request": request,
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
"total_keys": total_keys, # Renamed for clarity
|
||||
"valid_key_count": valid_key_count, # Added count
|
||||
"invalid_key_count": invalid_key_count, # Added count
|
||||
"api_stats": api_stats, # <-- Pass stats to template
|
||||
"total_keys": total_keys,
|
||||
"valid_key_count": valid_key_count,
|
||||
"invalid_key_count": invalid_key_count,
|
||||
"api_stats": api_stats,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving keys status or API stats: {str(e)}")
|
||||
# Optionally, render template with error or default stats
|
||||
# For now, re-raise to show error page
|
||||
raise
|
||||
|
||||
@app.get("/config", response_class=HTMLResponse)
|
||||
@@ -173,16 +171,13 @@ def setup_api_stats_routes(app: FastAPI) -> None:
|
||||
async def api_stats_details(request: Request, period: str):
|
||||
"""获取指定时间段内的 API 调用详情"""
|
||||
try:
|
||||
# 验证认证
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to API stats details")
|
||||
# Returning JSON error instead of redirect for API endpoint
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
logger.info(f"Fetching API call details for period: {period}")
|
||||
# Use the service instance here as well
|
||||
stats_service = StatsService() # Create an instance
|
||||
stats_service = StatsService()
|
||||
details = await stats_service.get_api_call_details(period)
|
||||
return details
|
||||
except ValueError as e:
|
||||
|
||||
@@ -2,22 +2,20 @@
|
||||
定时任务控制路由模块
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, status # 移除 Depends, 添加 Request
|
||||
from fastapi import APIRouter, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.security import verify_auth_token # 导入 verify_auth_token
|
||||
from app.scheduler.key_checker import start_scheduler, stop_scheduler
|
||||
from app.log.logger import get_scheduler_routes # 使用路由日志记录器
|
||||
from app.core.security import verify_auth_token
|
||||
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
||||
from app.log.logger import get_scheduler_routes
|
||||
|
||||
logger = get_scheduler_routes()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/scheduler",
|
||||
tags=["Scheduler"]
|
||||
# 移除全局依赖
|
||||
)
|
||||
|
||||
# 认证检查的辅助函数
|
||||
async def verify_token(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
@@ -29,14 +27,12 @@ async def verify_token(request: Request):
|
||||
)
|
||||
|
||||
@router.post("/start", summary="启动定时任务")
|
||||
async def start_scheduler_endpoint(request: Request): # 添加 request 参数
|
||||
async def start_scheduler_endpoint(request: Request):
|
||||
"""Start the background scheduler task"""
|
||||
"""
|
||||
await verify_token(request) # 在函数开始处进行认证检查
|
||||
"""
|
||||
await verify_token(request)
|
||||
try:
|
||||
logger.info("Received request to start scheduler.")
|
||||
start_scheduler() # 调用 key_checker 中的函数
|
||||
start_scheduler()
|
||||
return JSONResponse(content={"message": "Scheduler started successfully."}, status_code=status.HTTP_200_OK)
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting scheduler: {str(e)}", exc_info=True)
|
||||
@@ -46,14 +42,12 @@ async def start_scheduler_endpoint(request: Request): # 添加 request 参数
|
||||
)
|
||||
|
||||
@router.post("/stop", summary="停止定时任务")
|
||||
async def stop_scheduler_endpoint(request: Request): # 添加 request 参数
|
||||
async def stop_scheduler_endpoint(request: Request):
|
||||
"""Stop the background scheduler task"""
|
||||
"""
|
||||
await verify_token(request) # 在函数开始处进行认证检查
|
||||
"""
|
||||
await verify_token(request)
|
||||
try:
|
||||
logger.info("Received request to stop scheduler.")
|
||||
stop_scheduler() # 调用 key_checker 中的函数
|
||||
stop_scheduler()
|
||||
return JSONResponse(content={"message": "Scheduler stopped successfully."}, status_code=status.HTTP_200_OK)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping scheduler: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from starlette import status
|
||||
from app.core.security import verify_auth_token
|
||||
from app.service.stats_service import StatsService
|
||||
from app.log.logger import get_stats_logger # 使用路由日志记录器
|
||||
from app.service.stats.stats_service import StatsService
|
||||
from app.log.logger import get_stats_logger
|
||||
|
||||
logger = get_stats_logger()
|
||||
|
||||
|
||||
# 认证检查的辅助函数
|
||||
async def verify_token(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
@@ -21,7 +20,7 @@ async def verify_token(request: Request):
|
||||
router = APIRouter(
|
||||
prefix="/api",
|
||||
tags=["stats"],
|
||||
dependencies=[Depends(verify_token)] # Assuming API routes need authentication
|
||||
dependencies=[Depends(verify_token)]
|
||||
)
|
||||
|
||||
stats_service = StatsService()
|
||||
@@ -46,14 +45,10 @@ async def get_key_usage_details(key: str):
|
||||
try:
|
||||
usage_details = await stats_service.get_key_usage_details_last_24h(key)
|
||||
if usage_details is None:
|
||||
# Handle case where key might be valid but has no recent usage,
|
||||
# or if the service layer explicitly returns None for other reasons.
|
||||
# Returning an empty dict is usually fine for the frontend.
|
||||
return {}
|
||||
return usage_details
|
||||
except Exception as e:
|
||||
# Log the exception details here if needed
|
||||
print(f"Error fetching key usage details for key {key[:4]}...: {e}")
|
||||
logger.error(f"Error fetching key usage details for key {key[:4]}...: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取密钥使用详情时出错: {e}"
|
||||
|
||||
37
app/router/version_routes.py
Normal file
37
app/router/version_routes.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
from app.service.update.update_service import check_for_updates
|
||||
from app.utils.helpers import get_current_version
|
||||
from app.log.logger import get_update_logger
|
||||
|
||||
router = APIRouter(prefix="/api/version", tags=["Version"])
|
||||
logger = get_update_logger()
|
||||
|
||||
class VersionInfo(BaseModel):
|
||||
current_version: str = Field(..., description="当前应用程序版本")
|
||||
latest_version: Optional[str] = Field(None, description="可用的最新版本")
|
||||
update_available: bool = Field(False, description="是否有可用更新")
|
||||
error_message: Optional[str] = Field(None, description="检查更新时发生的错误信息")
|
||||
|
||||
@router.get("/check", response_model=VersionInfo, summary="检查应用程序更新")
|
||||
async def get_version_info():
|
||||
"""
|
||||
检查当前应用程序版本与最新的 GitHub release 版本。
|
||||
"""
|
||||
try:
|
||||
current_version = get_current_version()
|
||||
update_available, latest_version, error_message = await check_for_updates()
|
||||
|
||||
logger.info(f"Version check API result: current={current_version}, latest={latest_version}, available={update_available}, error='{error_message}'")
|
||||
|
||||
return VersionInfo(
|
||||
current_version=current_version,
|
||||
latest_version=latest_version,
|
||||
update_available=update_available,
|
||||
error_message=error_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in /api/version/check endpoint: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="检查版本信息时发生内部错误")
|
||||
146
app/router/vertex_express_routes.py
Normal file
146
app/router/vertex_express_routes.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from copy import deepcopy
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_vertex_express_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.service.chat.vertex_express_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
|
||||
router = APIRouter(prefix=f"/vertex-express/{API_VERSION}")
|
||||
logger = get_vertex_express_logger()
|
||||
|
||||
security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
"""获取密钥管理器实例"""
|
||||
return await get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取下一个可用的API密钥"""
|
||||
return await key_manager.get_next_working_vertex_key()
|
||||
|
||||
|
||||
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取Gemini聊天服务实例"""
|
||||
return GeminiChatService(settings.VERTEX_EXPRESS_BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
logger.info("-" * 50 + operation_name + "-" * 50)
|
||||
logger.info("Handling Gemini models list request")
|
||||
|
||||
try:
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
models_json["models"].append(item)
|
||||
|
||||
if settings.SEARCH_MODELS:
|
||||
for name in settings.SEARCH_MODELS:
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
|
||||
logger.info("Gemini models list request successful")
|
||||
return models_json
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
@@ -1,100 +0,0 @@
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.domain.gemini_models import GeminiRequest, GeminiContent
|
||||
from app.config.config import settings
|
||||
from app.log.logger import Logger # 导入 Logger 类
|
||||
|
||||
logger = Logger.setup_logger("scheduler") # 使用 Logger.setup_logger
|
||||
|
||||
async def check_failed_keys():
|
||||
"""
|
||||
定时检查失败次数大于0的API密钥,并尝试验证它们。
|
||||
如果验证成功,重置失败计数;如果失败,增加失败计数。
|
||||
"""
|
||||
logger.info("Starting scheduled check for failed API keys...")
|
||||
try:
|
||||
key_manager = await get_key_manager_instance()
|
||||
# 确保 KeyManager 已经初始化
|
||||
if not key_manager or not hasattr(key_manager, 'key_failure_counts'):
|
||||
logger.warning("KeyManager instance not available or not initialized. Skipping check.")
|
||||
return
|
||||
|
||||
# 创建 GeminiChatService 实例用于验证
|
||||
# 注意:这里直接创建实例,而不是通过依赖注入,因为这是后台任务
|
||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
# 获取需要检查的 key 列表 (失败次数 > 0)
|
||||
keys_to_check = []
|
||||
async with key_manager.failure_count_lock: # 访问共享数据需要加锁
|
||||
# 复制一份以避免在迭代时修改字典
|
||||
failure_counts_copy = key_manager.key_failure_counts.copy()
|
||||
keys_to_check = [key for key, count in failure_counts_copy.items() if count > 0] # 检查所有失败次数大于0的key
|
||||
|
||||
if not keys_to_check:
|
||||
logger.info("No keys with failure count > 0 found. Skipping verification.")
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(keys_to_check)} keys with failure count > 0 to verify.")
|
||||
|
||||
for key in keys_to_check:
|
||||
# 隐藏部分 key 用于日志记录
|
||||
log_key = f"{key[:4]}...{key[-4:]}" if len(key) > 8 else key
|
||||
logger.info(f"Verifying key: {log_key}...")
|
||||
try:
|
||||
# 构造测试请求
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=[{"text": "hi"}] # 使用简单的文本进行验证
|
||||
)
|
||||
]
|
||||
)
|
||||
# 调用 generate_content 进行验证
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL, # 使用配置中定义的测试模型
|
||||
gemini_request,
|
||||
key
|
||||
)
|
||||
# 如果没有抛出异常,说明 key 有效
|
||||
logger.info(f"Key {log_key} verification successful. Resetting failure count.")
|
||||
await key_manager.reset_key_failure_count(key)
|
||||
except Exception as e:
|
||||
# 验证失败,增加失败计数
|
||||
logger.warning(f"Key {log_key} verification failed: {str(e)}. Incrementing failure count.")
|
||||
# 直接操作计数器,需要加锁
|
||||
async with key_manager.failure_count_lock:
|
||||
# 再次检查 key 是否存在且失败次数未达上限
|
||||
if key in key_manager.key_failure_counts and key_manager.key_failure_counts[key] < key_manager.MAX_FAILURES:
|
||||
key_manager.key_failure_counts[key] += 1
|
||||
logger.info(f"Failure count for key {log_key} incremented to {key_manager.key_failure_counts[key]}.")
|
||||
elif key in key_manager.key_failure_counts:
|
||||
logger.warning(f"Key {log_key} reached MAX_FAILURES ({key_manager.MAX_FAILURES}). Not incrementing further.")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred during the scheduled key check: {str(e)}", exc_info=True)
|
||||
|
||||
def setup_scheduler():
|
||||
"""设置并启动 APScheduler"""
|
||||
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
||||
# 添加定时任务,例如每小时执行一次 (可以调整)
|
||||
scheduler.add_job(check_failed_keys, 'interval', hours=settings.CHECK_INTERVAL_HOURS)
|
||||
scheduler.start()
|
||||
logger.info(f"Scheduler started. Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s).")
|
||||
return scheduler
|
||||
|
||||
# 可以在这里添加一个全局的 scheduler 实例,以便在应用关闭时优雅地停止
|
||||
scheduler_instance = None
|
||||
|
||||
def start_scheduler():
|
||||
global scheduler_instance
|
||||
if scheduler_instance is None:
|
||||
scheduler_instance = setup_scheduler()
|
||||
|
||||
def stop_scheduler():
|
||||
global scheduler_instance
|
||||
if scheduler_instance and scheduler_instance.running:
|
||||
scheduler_instance.shutdown()
|
||||
logger.info("Scheduler stopped.")
|
||||
194
app/scheduler/scheduled_tasks.py
Normal file
194
app/scheduler/scheduled_tasks.py
Normal file
@@ -0,0 +1,194 @@
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest
|
||||
from app.log.logger import Logger
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.error_log.error_log_service import delete_old_error_logs
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
||||
from app.service.files.files_service import get_files_service
|
||||
|
||||
logger = Logger.setup_logger("scheduler")
|
||||
|
||||
|
||||
async def check_failed_keys():
|
||||
"""
|
||||
定时检查失败次数大于0的API密钥,并尝试验证它们。
|
||||
如果验证成功,重置失败计数;如果失败,增加失败计数。
|
||||
"""
|
||||
logger.info("Starting scheduled check for failed API keys...")
|
||||
try:
|
||||
key_manager = await get_key_manager_instance()
|
||||
# 确保 KeyManager 已经初始化
|
||||
if not key_manager or not hasattr(key_manager, "key_failure_counts"):
|
||||
logger.warning(
|
||||
"KeyManager instance not available or not initialized. Skipping check."
|
||||
)
|
||||
return
|
||||
|
||||
# 创建 GeminiChatService 实例用于验证
|
||||
# 注意:这里直接创建实例,而不是通过依赖注入,因为这是后台任务
|
||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
# 获取需要检查的 key 列表 (失败次数 > 0)
|
||||
keys_to_check = []
|
||||
async with key_manager.failure_count_lock: # 访问共享数据需要加锁
|
||||
# 复制一份以避免在迭代时修改字典
|
||||
failure_counts_copy = key_manager.key_failure_counts.copy()
|
||||
keys_to_check = [
|
||||
key for key, count in failure_counts_copy.items() if count > 0
|
||||
] # 检查所有失败次数大于0的key
|
||||
|
||||
if not keys_to_check:
|
||||
logger.info("No keys with failure count > 0 found. Skipping verification.")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Found {len(keys_to_check)} keys with failure count > 0 to verify."
|
||||
)
|
||||
|
||||
for key in keys_to_check:
|
||||
# 隐藏部分 key 用于日志记录
|
||||
log_key = f"{key[:4]}...{key[-4:]}" if len(key) > 8 else key
|
||||
logger.info(f"Verifying key: {log_key}...")
|
||||
try:
|
||||
# 构造测试请求
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=[{"text": "hi"}],
|
||||
)
|
||||
]
|
||||
)
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL, gemini_request, key
|
||||
)
|
||||
logger.info(
|
||||
f"Key {log_key} verification successful. Resetting failure count."
|
||||
)
|
||||
await key_manager.reset_key_failure_count(key)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Key {log_key} verification failed: {str(e)}. Incrementing failure count."
|
||||
)
|
||||
# 直接操作计数器,需要加锁
|
||||
async with key_manager.failure_count_lock:
|
||||
# 再次检查 key 是否存在且失败次数未达上限
|
||||
if (
|
||||
key in key_manager.key_failure_counts
|
||||
and key_manager.key_failure_counts[key]
|
||||
< key_manager.MAX_FAILURES
|
||||
):
|
||||
key_manager.key_failure_counts[key] += 1
|
||||
logger.info(
|
||||
f"Failure count for key {log_key} incremented to {key_manager.key_failure_counts[key]}."
|
||||
)
|
||||
elif key in key_manager.key_failure_counts:
|
||||
logger.warning(
|
||||
f"Key {log_key} reached MAX_FAILURES ({key_manager.MAX_FAILURES}). Not incrementing further."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled key check: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_expired_files():
|
||||
"""
|
||||
定时清理过期的文件记录
|
||||
"""
|
||||
logger.info("Starting scheduled cleanup for expired files...")
|
||||
try:
|
||||
files_service = await get_files_service()
|
||||
deleted_count = await files_service.cleanup_expired_files()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Successfully cleaned up {deleted_count} expired files.")
|
||||
else:
|
||||
logger.info("No expired files to clean up.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled file cleanup: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
def setup_scheduler():
|
||||
"""设置并启动 APScheduler"""
|
||||
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
||||
# 添加检查失败密钥的定时任务
|
||||
scheduler.add_job(
|
||||
check_failed_keys,
|
||||
"interval",
|
||||
hours=settings.CHECK_INTERVAL_HOURS,
|
||||
id="check_failed_keys_job",
|
||||
name="Check Failed API Keys",
|
||||
)
|
||||
logger.info(
|
||||
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
|
||||
)
|
||||
|
||||
# 新增:添加自动删除错误日志的定时任务,每天凌晨3点执行
|
||||
scheduler.add_job(
|
||||
delete_old_error_logs,
|
||||
"cron",
|
||||
hour=3,
|
||||
minute=0,
|
||||
id="delete_old_error_logs_job",
|
||||
name="Delete Old Error Logs",
|
||||
)
|
||||
logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")
|
||||
|
||||
# 新增:添加自动删除请求日志的定时任务,每天凌晨3点05分执行
|
||||
scheduler.add_job(
|
||||
delete_old_request_logs_task,
|
||||
"cron",
|
||||
hour=3,
|
||||
minute=5,
|
||||
id="delete_old_request_logs_job",
|
||||
name="Delete Old Request Logs",
|
||||
)
|
||||
logger.info(
|
||||
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
|
||||
)
|
||||
|
||||
# 新增:添加文件过期清理的定时任务,每小时执行一次
|
||||
if getattr(settings, 'FILES_CLEANUP_ENABLED', True):
|
||||
cleanup_interval = getattr(settings, 'FILES_CLEANUP_INTERVAL_HOURS', 1)
|
||||
scheduler.add_job(
|
||||
cleanup_expired_files,
|
||||
"interval",
|
||||
hours=cleanup_interval,
|
||||
id="cleanup_expired_files_job",
|
||||
name="Cleanup Expired Files",
|
||||
)
|
||||
logger.info(
|
||||
f"File cleanup job scheduled to run every {cleanup_interval} hour(s)."
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduler started with all jobs.")
|
||||
return scheduler
|
||||
|
||||
|
||||
# 可以在这里添加一个全局的 scheduler 实例,以便在应用关闭时优雅地停止
|
||||
scheduler_instance = None
|
||||
|
||||
|
||||
def start_scheduler():
|
||||
global scheduler_instance
|
||||
if scheduler_instance is None or not scheduler_instance.running:
|
||||
logger.info("Starting scheduler...")
|
||||
scheduler_instance = setup_scheduler()
|
||||
logger.info("Scheduler is already running.")
|
||||
|
||||
|
||||
def stop_scheduler():
|
||||
global scheduler_instance
|
||||
if scheduler_instance and scheduler_instance.running:
|
||||
scheduler_instance.shutdown()
|
||||
logger.info("Scheduler stopped.")
|
||||
@@ -2,17 +2,18 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
import datetime # Add datetime import
|
||||
import time # Add time import
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log # Import add_request_log
|
||||
from app.database.services import add_error_log, add_request_log, get_file_api_key
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
@@ -26,6 +27,55 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]:
|
||||
"""從內容中提取文件引用"""
|
||||
file_names = []
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
if not isinstance(part, dict) or "fileData" not in part:
|
||||
continue
|
||||
file_data = part["fileData"]
|
||||
if "fileUri" not in file_data:
|
||||
continue
|
||||
file_uri = file_data["fileUri"]
|
||||
# 從 URI 中提取文件名
|
||||
# 1. https://generativelanguage.googleapis.com/v1beta/files/{file_id}
|
||||
match = re.match(rf"{re.escape(settings.BASE_URL)}/(files/.*)", file_uri)
|
||||
if not match:
|
||||
logger.warning(f"Invalid file URI: {file_uri}")
|
||||
continue
|
||||
file_id = match.group(1)
|
||||
file_names.append(file_id)
|
||||
logger.info(f"Found file reference: {file_id}")
|
||||
return file_names
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
|
||||
"""清理JSON Schema中Gemini API不支持的字段"""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
# Gemini API不支持的JSON Schema字段
|
||||
unsupported_fields = {
|
||||
"exclusiveMaximum", "exclusiveMinimum", "const", "examples",
|
||||
"contentEncoding", "contentMediaType", "if", "then", "else",
|
||||
"allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
|
||||
"$id", "$ref", "$comment", "readOnly", "writeOnly"
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in obj.items():
|
||||
if key in unsupported_fields:
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = _clean_json_schema_properties(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
@@ -39,7 +89,15 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
for k, v in item.items():
|
||||
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||
functions = record.get("functionDeclarations", [])
|
||||
functions.extend(v)
|
||||
# 清理每个函数声明中的不支持字段
|
||||
cleaned_functions = []
|
||||
for func in v:
|
||||
if isinstance(func, dict):
|
||||
cleaned_func = _clean_json_schema_properties(func)
|
||||
cleaned_functions.append(cleaned_func)
|
||||
else:
|
||||
cleaned_functions.append(func)
|
||||
functions.extend(cleaned_functions)
|
||||
record["functionDeclarations"] = functions
|
||||
else:
|
||||
record[k] = v
|
||||
@@ -61,58 +119,122 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
tool["codeExecution"] = {}
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
if model.endswith("-non-thinking"):
|
||||
model = model[:-13]
|
||||
if "-search" in model and "-non-thinking" in model:
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Filters out contents with empty or invalid parts."""
|
||||
if not contents:
|
||||
return []
|
||||
|
||||
filtered_contents = []
|
||||
for content in contents:
|
||||
if not content or "parts" not in content or not isinstance(content.get("parts"), list):
|
||||
continue
|
||||
|
||||
valid_parts = [part for part in content["parts"] if isinstance(part, dict) and part]
|
||||
|
||||
if valid_parts:
|
||||
new_content = content.copy()
|
||||
new_content["parts"] = valid_parts
|
||||
filtered_contents.append(new_content)
|
||||
|
||||
return filtered_contents
|
||||
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
payload = {
|
||||
"contents": request_dict.get("contents", []),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig", {}),
|
||||
"systemInstruction": request_dict.get("systemInstruction", ""),
|
||||
}
|
||||
if "maxOutputTokens" in request_dict["generationConfig"]:
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
# 检查是否为TTS模型
|
||||
is_tts_model = "tts" in model.lower()
|
||||
|
||||
if is_tts_model:
|
||||
# TTS模型使用简化的payload,不包含tools和safetySettings
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
}
|
||||
|
||||
# 只在有systemInstruction时才添加
|
||||
if request_dict.get("systemInstruction"):
|
||||
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
||||
else:
|
||||
# 非TTS模型使用完整的payload
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
|
||||
# 确保 generationConfig 不为 None
|
||||
if payload["generationConfig"] is None:
|
||||
payload["generationConfig"] = {}
|
||||
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
|
||||
# 处理思考配置:优先使用客户端提供的配置,否则使用默认配置
|
||||
client_thinking_config = None
|
||||
if request.generationConfig and request.generationConfig.thinkingConfig:
|
||||
client_thinking_config = request.generationConfig.thinkingConfig
|
||||
|
||||
if client_thinking_config is not None:
|
||||
# 客户端提供了思考配置,直接使用
|
||||
payload["generationConfig"]["thinkingConfig"] = client_thinking_config
|
||||
else:
|
||||
# 客户端没有提供思考配置,使用默认配置
|
||||
if model.endswith("-non-thinking"):
|
||||
if "gemini-2.5-pro" in model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif model in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000),
|
||||
"includeThoughts": True
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
|
||||
return payload
|
||||
|
||||
@@ -142,7 +264,7 @@ class GeminiChatService:
|
||||
self, original_response: Dict[str, Any], text: str
|
||||
) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的响应"""
|
||||
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
|
||||
response_copy = json.loads(json.dumps(original_response))
|
||||
if response_copy.get("candidates") and response_copy["candidates"][0].get(
|
||||
"content", {}
|
||||
).get("parts"):
|
||||
@@ -153,9 +275,20 @@ class GeminiChatService:
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
payload = _build_payload(model, request)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now() # Record request time
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
@@ -163,20 +296,18 @@ class GeminiChatService:
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 on success
|
||||
status_code = 200
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
# Try to parse status code from exception
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default to 500 if parsing fails
|
||||
status_code = 500
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
@@ -185,11 +316,58 @@ class GeminiChatService:
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
)
|
||||
raise e # Re-throw exception for upstream handling
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
|
||||
async def count_tokens(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""计算token数量"""
|
||||
# countTokens API只需要contents
|
||||
payload = {"contents": _filter_empty_parts(request.model_dump().get("contents", []))}
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.count_tokens(payload, model, api_key)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Count tokens API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="gemini-count-tokens",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
# Log request to request log table
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
@@ -203,6 +381,17 @@ class GeminiChatService:
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
payload = _build_payload(model, request)
|
||||
@@ -214,7 +403,7 @@ class GeminiChatService:
|
||||
request_datetime = datetime.datetime.now()
|
||||
start_time = time.perf_counter()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key # Update final key used
|
||||
final_api_key = current_attempt_key
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, current_attempt_key
|
||||
@@ -251,16 +440,14 @@ class GeminiChatService:
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
# Parse error code for logging
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key, # Log key used for this failed attempt
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="gemini-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
@@ -268,28 +455,26 @@ class GeminiChatService:
|
||||
request_msg=payload
|
||||
)
|
||||
|
||||
# Attempt to switch API Key
|
||||
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else: # No more keys or retries exceeded by handle_api_failure logic
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break # Exit loop if no key available
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming."
|
||||
)
|
||||
break # Exit loop after max retries
|
||||
break
|
||||
finally:
|
||||
# Log the final outcome of the streaming request
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key, # Log the last key used
|
||||
is_success=is_success, # Log the final success status
|
||||
status_code=status_code, # Log the last known status code
|
||||
latency_ms=latency_ms, # Log total time including retries
|
||||
api_key=final_api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import datetime # Add datetime import
|
||||
import time # Add time import
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.database.services import (
|
||||
add_error_log,
|
||||
add_request_log,
|
||||
)
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.handler.message_converter import OpenAIMessageConverter
|
||||
from app.handler.response_handler import OpenAIResponseHandler
|
||||
@@ -16,21 +22,47 @@ from app.log.logger import get_openai_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log # Import add_request_log
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片部分"""
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
def _has_media_parts(messages: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含多媒体部分"""
|
||||
for message in messages:
|
||||
if "parts" in message:
|
||||
for part in message["parts"]:
|
||||
if "image_url" in part or "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
|
||||
"""清理JSON Schema中Gemini API不支持的字段"""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
# Gemini API不支持的JSON Schema字段
|
||||
unsupported_fields = {
|
||||
"exclusiveMaximum", "exclusiveMinimum", "const", "examples",
|
||||
"contentEncoding", "contentMediaType", "if", "then", "else",
|
||||
"allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
|
||||
"$id", "$ref", "$comment", "readOnly", "writeOnly"
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in obj.items():
|
||||
if key in unsupported_fields:
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = _clean_json_schema_properties(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _build_tools(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -46,11 +78,19 @@ def _build_tools(
|
||||
or model.endswith("-image")
|
||||
or model.endswith("-image-generation")
|
||||
)
|
||||
and not _has_image_parts(messages)
|
||||
and not _has_media_parts(messages)
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
logger.debug("Code execution tool enabled.")
|
||||
elif _has_media_parts(messages):
|
||||
logger.debug("Code execution tool disabled due to media parts presence.")
|
||||
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 将 request 中的 tools 合并到 tools 中
|
||||
if request.tools:
|
||||
@@ -62,9 +102,13 @@ def _build_tools(
|
||||
if item.get("type", "") == "function" and item.get("function"):
|
||||
function = deepcopy(item.get("function"))
|
||||
parameters = function.get("parameters", {})
|
||||
if parameters.get("type") == "object" and not parameters.get("properties", {}):
|
||||
if parameters.get("type") == "object" and not parameters.get(
|
||||
"properties", {}
|
||||
):
|
||||
function.pop("parameters", None)
|
||||
|
||||
# 清理函数中的不支持字段
|
||||
function = _clean_json_schema_properties(function)
|
||||
function_declarations.append(function)
|
||||
|
||||
if function_declarations:
|
||||
@@ -72,8 +116,13 @@ def _build_tools(
|
||||
names, functions = set(), []
|
||||
for fc in function_declarations:
|
||||
if fc.get("name") not in names:
|
||||
names.add(fc.get("name"))
|
||||
functions.append(fc)
|
||||
if fc.get("name")=="googleSearch":
|
||||
# cherry开启内置搜索时,添加googleSearch工具
|
||||
tool["googleSearch"] = {}
|
||||
else:
|
||||
# 其他函数,添加到functionDeclarations中
|
||||
names.add(fc.get("name"))
|
||||
functions.append(fc)
|
||||
|
||||
tool["functionDeclarations"] = functions
|
||||
|
||||
@@ -81,10 +130,23 @@ def _build_tools(
|
||||
if tool.get("functionDeclarations"):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext",None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
if model.endswith("-non-thinking"):
|
||||
model = model[:-13]
|
||||
if "-search" in model and "-non-thinking" in model:
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
# if (
|
||||
@@ -93,20 +155,25 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
# and "gemini-2.0-pro-exp" not in model
|
||||
# ):
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _validate_and_set_max_tokens(
|
||||
payload: Dict[str, Any],
|
||||
max_tokens: Optional[int],
|
||||
logger_instance
|
||||
) -> None:
|
||||
"""验证并设置 max_tokens 参数"""
|
||||
if max_tokens is None:
|
||||
return
|
||||
|
||||
# 参数验证和处理
|
||||
if max_tokens <= 0:
|
||||
logger_instance.warning(f"Invalid max_tokens value: {max_tokens}, will not set maxOutputTokens")
|
||||
# 不设置 maxOutputTokens,让 Gemini API 使用默认值
|
||||
else:
|
||||
payload["generationConfig"]["maxOutputTokens"] = max_tokens
|
||||
|
||||
|
||||
def _build_payload(
|
||||
@@ -126,14 +193,27 @@ def _build_payload(
|
||||
"tools": _build_tools(request, messages),
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
if request.max_tokens is not None:
|
||||
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
|
||||
|
||||
# 处理 max_tokens 参数
|
||||
_validate_and_set_max_tokens(payload, request.max_tokens, logger)
|
||||
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if request.model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if "gemini-2.5-pro" in request.model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
|
||||
if request.model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model,1000)}
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000),
|
||||
"includeThoughts": True
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)}
|
||||
|
||||
if (
|
||||
instruction
|
||||
@@ -172,7 +252,7 @@ class OpenAIChatService:
|
||||
self, original_chunk: Dict[str, Any], text: str
|
||||
) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的OpenAI响应块"""
|
||||
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
|
||||
chunk_copy = json.loads(json.dumps(original_chunk))
|
||||
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
|
||||
chunk_copy["choices"][0]["delta"]["content"] = text
|
||||
return chunk_copy
|
||||
@@ -183,10 +263,8 @@ class OpenAIChatService:
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
# 转换消息格式
|
||||
messages, instruction = self.message_converter.convert(request.messages)
|
||||
|
||||
# 构建请求payload
|
||||
payload = _build_payload(request, messages, instruction)
|
||||
|
||||
if request.stream:
|
||||
@@ -202,49 +280,190 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
usage_metadata = response.get("usageMetadata", {})
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 on success
|
||||
return self.response_handler.handle_response(
|
||||
response, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
status_code = 200
|
||||
|
||||
# 尝试处理响应,捕获可能的响应处理异常
|
||||
try:
|
||||
result = self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
return result
|
||||
except Exception as response_error:
|
||||
logger.error(f"Response processing failed for model {model}: {str(response_error)}")
|
||||
|
||||
# 记录详细的错误信息
|
||||
if "parts" in str(response_error):
|
||||
logger.error("Response structure issue - missing or invalid parts")
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
logger.error(f"Content structure: {content}")
|
||||
|
||||
# 重新抛出异常
|
||||
raise response_error
|
||||
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
# Try to parse status code from exception
|
||||
logger.error(f"API call failed for model {model}: {error_log_msg}")
|
||||
|
||||
# 特别记录 max_tokens 相关的错误
|
||||
gen_config = payload.get('generationConfig', {})
|
||||
if "maxOutputTokens" in gen_config:
|
||||
logger.error(f"Request had maxOutputTokens: {gen_config['maxOutputTokens']}")
|
||||
|
||||
# 如果是响应处理错误,记录更多信息
|
||||
if "parts" in error_log_msg:
|
||||
logger.error("This is likely a response processing error")
|
||||
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default if parsing fails
|
||||
status_code = int(match.group(1)) if match else 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key, # Note: Parameter name is gemini_key in add_error_log
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=payload,
|
||||
)
|
||||
raise e # Re-throw exception
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(f"Normal completion finished - Success: {is_success}, Latency: {latency_ms}ms")
|
||||
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _fake_stream_logic_impl(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理伪流式 (fake stream) 的核心逻辑"""
|
||||
logger.info(
|
||||
f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint."
|
||||
)
|
||||
keep_sending_empty_data = True
|
||||
|
||||
async def send_empty_data_locally() -> AsyncGenerator[str, None]:
|
||||
"""定期发送空数据以保持连接"""
|
||||
while keep_sending_empty_data:
|
||||
await asyncio.sleep(settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS)
|
||||
if keep_sending_empty_data:
|
||||
empty_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
||||
yield f"data: {json.dumps(empty_chunk)}\n\n"
|
||||
logger.debug("Sent empty data chunk for fake stream heartbeat.")
|
||||
|
||||
empty_data_generator = send_empty_data_locally()
|
||||
api_response_task = asyncio.create_task(
|
||||
self.api_client.generate_content(payload, model, api_key)
|
||||
)
|
||||
|
||||
try:
|
||||
while not api_response_task.done():
|
||||
try:
|
||||
next_empty_chunk = await asyncio.wait_for(
|
||||
empty_data_generator.__anext__(), timeout=0.1
|
||||
)
|
||||
yield next_empty_chunk
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except (
|
||||
StopAsyncIteration
|
||||
):
|
||||
break
|
||||
|
||||
response = await api_response_task
|
||||
finally:
|
||||
keep_sending_empty_data = False
|
||||
|
||||
if response and response.get("candidates"):
|
||||
response = self.response_handler.handle_response(response, model, stream=True, finish_reason='stop', usage_metadata=response.get("usageMetadata", {}))
|
||||
yield f"data: {json.dumps(response)}\n\n"
|
||||
logger.info(f"Sent full response content for fake stream: {model}")
|
||||
else:
|
||||
error_message = "Failed to get response from model"
|
||||
if (
|
||||
response and isinstance(response, dict) and response.get("error")
|
||||
):
|
||||
error_details = response.get("error")
|
||||
if isinstance(error_details, dict):
|
||||
error_message = error_details.get("message", error_message)
|
||||
|
||||
logger.error(
|
||||
f"No candidates or error in response for fake stream model {model}: {response}"
|
||||
)
|
||||
error_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
||||
yield f"data: {json.dumps(error_chunk)}\n\n"
|
||||
|
||||
async def _real_stream_logic_impl(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理真实流式 (real stream) 的核心逻辑"""
|
||||
tool_call_flag = False
|
||||
usage_metadata = None
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, api_key
|
||||
):
|
||||
if line.startswith("data:"):
|
||||
chunk_str = line[6:]
|
||||
if not chunk_str or chunk_str.isspace():
|
||||
logger.debug(
|
||||
f"Received empty data line for model {model}, skipping."
|
||||
)
|
||||
continue
|
||||
try:
|
||||
chunk = json.loads(chunk_str)
|
||||
usage_metadata = chunk.get("usageMetadata", {})
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Failed to decode JSON from stream for model {model}: {chunk_str}"
|
||||
)
|
||||
continue
|
||||
openai_chunk = self.response_handler.handle_response(
|
||||
chunk, model, stream=True, finish_reason=None, usage_metadata=usage_metadata
|
||||
)
|
||||
if openai_chunk:
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text and settings.STREAM_OPTIMIZER_ENABLED:
|
||||
async for (
|
||||
optimized_chunk_data
|
||||
) in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||
):
|
||||
yield optimized_chunk_data
|
||||
else:
|
||||
if openai_chunk.get("choices") and openai_chunk["choices"][0].get("delta", {}).get("tool_calls"):
|
||||
tool_call_flag = True
|
||||
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
|
||||
if tool_call_flag:
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls', usage_metadata=usage_metadata))}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=usage_metadata))}\n\n"
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑"""
|
||||
"""处理流式聊天完成,添加重试逻辑和假流式支持"""
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
is_success = False
|
||||
@@ -254,110 +473,107 @@ class OpenAIChatService:
|
||||
while retries < max_retries:
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key
|
||||
current_attempt_key = final_api_key
|
||||
|
||||
try:
|
||||
tool_call_flag = False
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, current_attempt_key
|
||||
):
|
||||
if line.startswith("data:"):
|
||||
chunk = json.loads(line[6:])
|
||||
openai_chunk = self.response_handler.handle_response(
|
||||
chunk, model, stream=True, finish_reason=None
|
||||
)
|
||||
if openai_chunk:
|
||||
# 提取文本内容
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text and settings.STREAM_OPTIMIZER_ENABLED:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for (
|
||||
optimized_chunk
|
||||
) in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(
|
||||
openai_chunk, t
|
||||
),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
# 如果没有文本内容(如工具调用等),整块输出
|
||||
if "tool_calls" in json.dumps(openai_chunk):
|
||||
tool_call_flag = True
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
if tool_call_flag:
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls'))}\n\n"
|
||||
stream_generator = None
|
||||
if settings.FAKE_STREAM_ENABLED:
|
||||
logger.info(
|
||||
f"Using fake stream logic for model: {model}, Attempt: {retries + 1}"
|
||||
)
|
||||
stream_generator = self._fake_stream_logic_impl(
|
||||
model, payload, current_attempt_key
|
||||
)
|
||||
else:
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
logger.info(
|
||||
f"Using real stream logic for model: {model}, Attempt: {retries + 1}"
|
||||
)
|
||||
stream_generator = self._real_stream_logic_impl(
|
||||
model, payload, current_attempt_key
|
||||
)
|
||||
|
||||
async for chunk_data in stream_generator:
|
||||
yield chunk_data
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
logger.info(
|
||||
f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 on success
|
||||
break # 成功后退出循环
|
||||
status_code = 200
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
|
||||
)
|
||||
# Parse error code for logging
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
|
||||
match = re.search(r"status code (\\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default if parsing fails
|
||||
if isinstance(e, asyncio.TimeoutError):
|
||||
status_code = 408
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=payload,
|
||||
)
|
||||
|
||||
# Attempt to switch API Key
|
||||
# Ensure key_manager is available (might need adjustment if not always passed)
|
||||
if self.key_manager:
|
||||
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break # Exit loop if no key available
|
||||
new_api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if new_api_key and new_api_key != current_attempt_key:
|
||||
final_api_key = new_api_key
|
||||
logger.info(
|
||||
f"Switched to new API key for next attempt: {final_api_key}"
|
||||
)
|
||||
elif not new_api_key:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries, ceasing attempts for this request."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break # Exit loop if key manager is missing
|
||||
logger.error(
|
||||
"KeyManager not available, cannot switch API key. Ceasing attempts for this request."
|
||||
)
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming."
|
||||
f"Max retries ({max_retries}) reached for streaming model {model}."
|
||||
)
|
||||
break # Exit loop after max retries
|
||||
finally:
|
||||
# Log the final outcome of the streaming request
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key, # Log the last key used
|
||||
is_success=is_success, # Log the final success status
|
||||
status_code=status_code, # Log the last known status code
|
||||
latency_ms=latency_ms, # Log total time including retries
|
||||
request_time=request_datetime
|
||||
api_key=current_attempt_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
# If the loop finished due to failure, yield error and DONE
|
||||
if not is_success and retries >= max_retries:
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
if not is_success:
|
||||
logger.error(
|
||||
f"Streaming failed permanently for model {model} after {retries} attempts."
|
||||
)
|
||||
yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str
|
||||
self, request: ChatRequest, api_key: str
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
|
||||
image_generate_request = ImageGenerationRequest()
|
||||
@@ -367,18 +583,22 @@ class OpenAIChatService:
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_image_completion(request.model, image_res, api_key)
|
||||
return self._handle_stream_image_completion(
|
||||
request.model, image_res, api_key
|
||||
)
|
||||
else:
|
||||
return await self._handle_normal_image_completion(request.model, image_res, api_key)
|
||||
return await self._handle_normal_image_completion(
|
||||
request.model, image_res, api_key
|
||||
)
|
||||
|
||||
async def _handle_stream_image_completion(
|
||||
self, model: str, image_data: str, api_key:str
|
||||
self, model: str, image_data: str, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
logger.info(f"Starting stream image completion for model: {model}")
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now() # Although not used for DB log here
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None # Although not used for DB log here
|
||||
status_code = None
|
||||
|
||||
try:
|
||||
if image_data:
|
||||
@@ -402,7 +622,9 @@ class OpenAIChatService:
|
||||
# 如果没有文本内容(如图片URL等),整块输出
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
logger.info(f"Stream image completion finished successfully for model: {model}")
|
||||
logger.info(
|
||||
f"Stream image completion finished successfully for model: {model}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -410,48 +632,49 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
error_log_msg = f"Stream image completion failed for model {model}: {e}"
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500 # Default error code
|
||||
# Call add_error_log using the passed api_key
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-stream", # Specific error type
|
||||
error_type="openai-image-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
)
|
||||
yield f"data: {json.dumps({'error': error_log_msg})}\n\n" # Send error to client
|
||||
yield "data: [DONE]\n\n" # Still need DONE message
|
||||
# Re-raising might break the stream, decide if needed
|
||||
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}")
|
||||
# Call add_request_log using the passed api_key
|
||||
logger.info(
|
||||
f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _handle_normal_image_completion(
|
||||
self, model: str, image_data: str, api_key: str # Add api_key parameter
|
||||
self, model: str, image_data: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"Starting normal image completion for model: {model}")
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now() # Although not used for DB log here
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None # Although not used for DB log here
|
||||
status_code = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
result = self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
logger.info(f"Normal image completion finished successfully for model: {model}")
|
||||
logger.info(
|
||||
f"Normal image completion finished successfully for model: {model}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return result
|
||||
@@ -459,28 +682,27 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
error_log_msg = f"Normal image completion failed for model {model}: {e}"
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500 # Default error code
|
||||
# Call add_error_log using the passed api_key
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-non-stream", # Specific error type
|
||||
error_type="openai-image-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
)
|
||||
# Re-raise the exception so the caller knows about the failure
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}")
|
||||
# Call add_request_log using the passed api_key
|
||||
logger.info(
|
||||
f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
348
app/service/chat/vertex_express_chat_service.py
Normal file
348
app/service/chat/vertex_express_chat_service.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import json
|
||||
import re
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片部分"""
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
if "image_url" in part or "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
|
||||
"""清理JSON Schema中Gemini API不支持的字段"""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
# Gemini API不支持的JSON Schema字段
|
||||
unsupported_fields = {
|
||||
"exclusiveMaximum", "exclusiveMinimum", "const", "examples",
|
||||
"contentEncoding", "contentMediaType", "if", "then", "else",
|
||||
"allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
|
||||
"$id", "$ref", "$comment", "readOnly", "writeOnly"
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in obj.items():
|
||||
if key in unsupported_fields:
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = _clean_json_schema_properties(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
|
||||
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
record = dict()
|
||||
for item in tools:
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
for k, v in item.items():
|
||||
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||
functions = record.get("functionDeclarations", [])
|
||||
# 清理每个函数声明中的不支持字段
|
||||
cleaned_functions = []
|
||||
for func in v:
|
||||
if isinstance(func, dict):
|
||||
cleaned_func = _clean_json_schema_properties(func)
|
||||
cleaned_functions.append(cleaned_func)
|
||||
else:
|
||||
cleaned_functions.append(func)
|
||||
functions.extend(cleaned_functions)
|
||||
record["functionDeclarations"] = functions
|
||||
else:
|
||||
record[k] = v
|
||||
return record
|
||||
|
||||
tool = dict()
|
||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||
if payload.get("tools") and isinstance(payload.get("tools"), dict):
|
||||
payload["tools"] = [payload.get("tools")]
|
||||
items = payload.get("tools", [])
|
||||
if items and isinstance(items, list):
|
||||
tool.update(_merge_tools(items))
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(payload.get("contents", []))
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
if model.endswith("-non-thinking"):
|
||||
model = model[:-13]
|
||||
if "-search" in model and "-non-thinking" in model:
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
payload = {
|
||||
"contents": request_dict.get("contents", []),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
# 处理思考配置:优先使用客户端提供的配置,否则使用默认配置
|
||||
client_thinking_config = None
|
||||
if request.generationConfig and request.generationConfig.thinkingConfig:
|
||||
client_thinking_config = request.generationConfig.thinkingConfig
|
||||
|
||||
if client_thinking_config is not None:
|
||||
# 客户端提供了思考配置,直接使用
|
||||
payload["generationConfig"]["thinkingConfig"] = client_thinking_config
|
||||
else:
|
||||
# 客户端没有提供思考配置,使用默认配置
|
||||
if model.endswith("-non-thinking"):
|
||||
if "gemini-2.5-pro" in model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif model in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000),
|
||||
"includeThoughts": True
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class GeminiChatService:
|
||||
"""聊天服务"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager):
|
||||
self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
|
||||
self.key_manager = key_manager
|
||||
self.response_handler = GeminiResponseHandler()
|
||||
|
||||
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
|
||||
"""从响应中提取文本内容"""
|
||||
if not response.get("candidates"):
|
||||
return ""
|
||||
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts and "text" in parts[0]:
|
||||
return parts[0].get("text", "")
|
||||
return ""
|
||||
|
||||
def _create_char_response(
|
||||
self, original_response: Dict[str, Any], text: str
|
||||
) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的响应"""
|
||||
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
|
||||
if response_copy.get("candidates") and response_copy["candidates"][0].get(
|
||||
"content", {}
|
||||
).get("parts"):
|
||||
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
|
||||
return response_copy
|
||||
|
||||
async def generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
payload = _build_payload(model, request)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="gemini-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
|
||||
async def stream_generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
payload = _build_payload(model, request)
|
||||
is_success = False
|
||||
status_code = None
|
||||
final_api_key = api_key
|
||||
|
||||
while retries < max_retries:
|
||||
request_datetime = datetime.datetime.now()
|
||||
start_time = time.perf_counter()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key # Update final key used
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, current_attempt_key
|
||||
):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
line = line[6:]
|
||||
response_data = self.response_handler.handle_response(
|
||||
json.loads(line), model, stream=True
|
||||
)
|
||||
text = self._extract_text_from_response(response_data)
|
||||
# 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理
|
||||
if text and settings.STREAM_OPTIMIZER_ENABLED:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for (
|
||||
optimized_chunk
|
||||
) in gemini_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_response(response_data, t),
|
||||
lambda c: "data: " + json.dumps(c) + "\n\n",
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
# 如果没有文本内容(如工具调用等),整块输出
|
||||
yield "data: " + json.dumps(response_data) + "\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
is_success = True
|
||||
status_code = 200
|
||||
break
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="gemini-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
)
|
||||
|
||||
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming."
|
||||
)
|
||||
break
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
@@ -1,11 +1,14 @@
|
||||
# app/services/chat/api_client.py
|
||||
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
from typing import Dict, Any, AsyncGenerator, Optional
|
||||
import httpx
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_api_client_logger
|
||||
from app.core.constants import DEFAULT_TIMEOUT
|
||||
|
||||
logger = get_api_client_logger()
|
||||
|
||||
class ApiClient(ABC):
|
||||
"""API客户端基类"""
|
||||
@@ -37,28 +40,243 @@ class GeminiApiClient(ApiClient):
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
def _prepare_headers(self) -> Dict[str, str]:
|
||||
headers = {}
|
||||
if settings.CUSTOM_HEADERS:
|
||||
headers.update(settings.CUSTOM_HEADERS)
|
||||
logger.info(f"Using custom headers: {settings.CUSTOM_HEADERS}")
|
||||
return headers
|
||||
|
||||
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取可用的 Gemini 模型列表"""
|
||||
timeout = httpx.Timeout(timeout=5)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models?key={api_key}&pageSize=1000"
|
||||
try:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"获取模型列表失败: {e.response.status_code}")
|
||||
logger.error(e.response.text)
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求模型列表失败: {e}")
|
||||
return None
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(f"API call failed - Status: {response.status_code}, Content: {error_content}")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# 检查响应结构的基本信息
|
||||
if not response_data.get("candidates"):
|
||||
logger.warning("No candidates found in API response")
|
||||
|
||||
return response_data
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Request timeout: {e}")
|
||||
raise Exception(f"Request timeout: {e}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
raise Exception(f"Request error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
raise
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for counting tokens: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:countTokens?key={api_key}"
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
"""OpenAI API客户端"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
def _prepare_headers(self, api_key: str) -> Dict[str, str]:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
if settings.CUSTOM_HEADERS:
|
||||
headers.update(settings.CUSTOM_HEADERS)
|
||||
logger.info(f"Using custom headers: {settings.CUSTOM_HEADERS}")
|
||||
return headers
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/models"
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
logger.info(f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}")
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/embeddings"
|
||||
payload = {
|
||||
"input": input,
|
||||
"model": model,
|
||||
}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/images/generations"
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
@@ -1,41 +1,49 @@
|
||||
"""
|
||||
配置服务模块
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import insert, update
|
||||
|
||||
from app.config.config import Settings as ConfigSettings
|
||||
from app.config.config import settings
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings
|
||||
from app.config.config import Settings as ConfigSettings
|
||||
from app.database.services import get_all_settings
|
||||
from app.service.key.key_manager import get_key_manager_instance, reset_key_manager_instance
|
||||
from app.log.logger import get_config_routes_logger
|
||||
from app.service.key.key_manager import (
|
||||
get_key_manager_instance,
|
||||
reset_key_manager_instance,
|
||||
)
|
||||
from app.service.model.model_service import ModelService
|
||||
|
||||
logger = get_config_routes_logger()
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""配置服务类,用于管理应用程序配置"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def get_config() -> Dict[str, Any]:
|
||||
return settings.model_dump()
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def update_config(config_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
for key, value in config_data.items():
|
||||
if hasattr(settings, key):
|
||||
setattr(settings, key, value)
|
||||
logger.info(f"Updated setting in memory: {key}")
|
||||
|
||||
logger.debug(f"Updated setting in memory: {key}")
|
||||
|
||||
# 获取现有设置
|
||||
existing_settings_raw: List[Dict[str, Any]] = await get_all_settings()
|
||||
existing_settings_map: Dict[str, Dict[str, Any]] = {s['key']: s for s in existing_settings_raw}
|
||||
existing_settings_map: Dict[str, Dict[str, Any]] = {
|
||||
s["key"]: s for s in existing_settings_raw
|
||||
}
|
||||
existing_keys = set(existing_settings_map.keys())
|
||||
|
||||
settings_to_update: List[Dict[str, Any]] = []
|
||||
@@ -47,7 +55,7 @@ class ConfigService:
|
||||
# 处理不同类型的值
|
||||
if isinstance(value, list):
|
||||
db_value = json.dumps(value)
|
||||
elif isinstance(value, dict): # 新增对 dict 类型的处理
|
||||
elif isinstance(value, dict):
|
||||
db_value = json.dumps(value)
|
||||
elif isinstance(value, bool):
|
||||
db_value = str(value).lower()
|
||||
@@ -55,24 +63,25 @@ class ConfigService:
|
||||
db_value = str(value)
|
||||
|
||||
# 仅当值发生变化时才更新
|
||||
if key in existing_keys and existing_settings_map[key]['value'] == db_value:
|
||||
continue
|
||||
if key in existing_keys and existing_settings_map[key]["value"] == db_value:
|
||||
continue
|
||||
|
||||
description = f"{key}配置项"
|
||||
description = f"{key}配置项"
|
||||
|
||||
data = {
|
||||
'key': key,
|
||||
'value': db_value,
|
||||
'description': description,
|
||||
'updated_at': now
|
||||
"key": key,
|
||||
"value": db_value,
|
||||
"description": description,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
if key in existing_keys:
|
||||
# Preserve original description if not explicitly provided
|
||||
data['description'] = existing_settings_map[key].get('description', description)
|
||||
data["description"] = existing_settings_map[key].get(
|
||||
"description", description
|
||||
)
|
||||
settings_to_update.append(data)
|
||||
else:
|
||||
data['created_at'] = now
|
||||
data["created_at"] = now
|
||||
settings_to_insert.append(data)
|
||||
|
||||
# 在事务中执行批量插入和更新
|
||||
@@ -82,37 +91,109 @@ class ConfigService:
|
||||
if settings_to_insert:
|
||||
query_insert = insert(Settings).values(settings_to_insert)
|
||||
await database.execute(query=query_insert)
|
||||
logger.info(f"Bulk inserted {len(settings_to_insert)} settings.")
|
||||
logger.info(
|
||||
f"Bulk inserted {len(settings_to_insert)} settings."
|
||||
)
|
||||
|
||||
if settings_to_update:
|
||||
for setting_data in settings_to_update:
|
||||
query_update = (
|
||||
update(Settings)
|
||||
.where(Settings.key == setting_data['key'])
|
||||
.where(Settings.key == setting_data["key"])
|
||||
.values(
|
||||
value=setting_data['value'],
|
||||
description=setting_data['description'],
|
||||
updated_at=setting_data['updated_at']
|
||||
value=setting_data["value"],
|
||||
description=setting_data["description"],
|
||||
updated_at=setting_data["updated_at"],
|
||||
)
|
||||
)
|
||||
await database.execute(query=query_update)
|
||||
logger.info(f"Updated {len(settings_to_update)} settings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to bulk update/insert settings: {str(e)}")
|
||||
raise # Re-raise the exception after logging
|
||||
raise
|
||||
|
||||
# 重置并重新初始化 KeyManager
|
||||
try:
|
||||
await reset_key_manager_instance()
|
||||
await get_key_manager_instance(settings.API_KEYS)
|
||||
await get_key_manager_instance(settings.API_KEYS, settings.VERTEX_API_KEYS)
|
||||
logger.info("KeyManager instance re-initialized with updated settings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to re-initialize KeyManager: {str(e)}")
|
||||
# Decide if this error should prevent returning the updated config
|
||||
# For now, we log the error and continue
|
||||
|
||||
return await ConfigService.get_config()
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def delete_key(key_to_delete: str) -> Dict[str, Any]:
|
||||
"""删除单个API密钥"""
|
||||
# 确保 settings.API_KEYS 是一个列表
|
||||
if not isinstance(settings.API_KEYS, list):
|
||||
settings.API_KEYS = []
|
||||
|
||||
original_keys_count = len(settings.API_KEYS)
|
||||
# 创建一个不包含待删除密钥的新列表
|
||||
updated_api_keys = [k for k in settings.API_KEYS if k != key_to_delete]
|
||||
|
||||
if len(updated_api_keys) < original_keys_count:
|
||||
# 密钥已找到并从列表中移除
|
||||
settings.API_KEYS = updated_api_keys # 首先更新内存中的 settings
|
||||
# 使用 update_config 持久化更改,它同时处理数据库和 KeyManager
|
||||
await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
|
||||
logger.info(f"密钥 '{key_to_delete}' 已成功删除。")
|
||||
return {"success": True, "message": f"密钥 '{key_to_delete}' 已成功删除。"}
|
||||
else:
|
||||
# 未找到密钥
|
||||
logger.warning(f"尝试删除密钥 '{key_to_delete}',但未找到该密钥。")
|
||||
return {"success": False, "message": f"未找到密钥 '{key_to_delete}'。"}
|
||||
|
||||
@staticmethod
|
||||
async def delete_selected_keys(keys_to_delete: List[str]) -> Dict[str, Any]:
|
||||
"""批量删除选定的API密钥"""
|
||||
if not isinstance(settings.API_KEYS, list):
|
||||
settings.API_KEYS = []
|
||||
|
||||
deleted_count = 0
|
||||
not_found_keys: List[str] = []
|
||||
|
||||
current_api_keys = list(settings.API_KEYS)
|
||||
keys_actually_removed: List[str] = []
|
||||
|
||||
for key_to_del in keys_to_delete:
|
||||
if key_to_del in current_api_keys:
|
||||
current_api_keys.remove(key_to_del)
|
||||
keys_actually_removed.append(key_to_del)
|
||||
deleted_count += 1
|
||||
else:
|
||||
not_found_keys.append(key_to_del)
|
||||
|
||||
if deleted_count > 0:
|
||||
settings.API_KEYS = current_api_keys
|
||||
await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
|
||||
logger.info(
|
||||
f"成功删除 {deleted_count} 个密钥。密钥: {keys_actually_removed}"
|
||||
)
|
||||
message = f"成功删除 {deleted_count} 个密钥。"
|
||||
if not_found_keys:
|
||||
message += f" {len(not_found_keys)} 个密钥未找到: {not_found_keys}。"
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"deleted_count": deleted_count,
|
||||
"not_found_keys": not_found_keys,
|
||||
}
|
||||
else:
|
||||
message = "没有密钥被删除。"
|
||||
if not_found_keys:
|
||||
message = f"所有 {len(not_found_keys)} 个指定的密钥均未找到: {not_found_keys}。"
|
||||
elif not keys_to_delete:
|
||||
message = "未指定要删除的密钥。"
|
||||
logger.warning(message)
|
||||
return {
|
||||
"success": False,
|
||||
"message": message,
|
||||
"deleted_count": 0,
|
||||
"not_found_keys": not_found_keys,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def reset_config() -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -124,7 +205,9 @@ class ConfigService:
|
||||
"""
|
||||
# 1. 重新加载配置对象,它应该处理环境变量和 .env 的优先级
|
||||
_reload_settings()
|
||||
logger.info("Settings object reloaded, prioritizing system environment variables then .env file.")
|
||||
logger.info(
|
||||
"Settings object reloaded, prioritizing system environment variables then .env file."
|
||||
)
|
||||
|
||||
# 2. 重置并重新初始化 KeyManager
|
||||
try:
|
||||
@@ -140,6 +223,34 @@ class ConfigService:
|
||||
# 3. 返回更新后的配置
|
||||
return await ConfigService.get_config()
|
||||
|
||||
@staticmethod
|
||||
async def fetch_ui_models() -> List[Dict[str, Any]]:
|
||||
"""获取用于UI显示的模型列表"""
|
||||
try:
|
||||
key_manager = await get_key_manager_instance()
|
||||
model_service = ModelService()
|
||||
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
if not api_key:
|
||||
logger.error("No valid API keys available to fetch model list for UI.")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="No valid API keys available to fetch model list.",
|
||||
)
|
||||
|
||||
models = await model_service.get_gemini_openai_models(api_key)
|
||||
return models
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch models for UI in ConfigService: {e}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch models for UI: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# 重新加载配置的函数
|
||||
def _reload_settings():
|
||||
"""重新加载环境变量并更新配置"""
|
||||
@@ -147,4 +258,4 @@ def _reload_settings():
|
||||
load_dotenv(find_dotenv(), override=True)
|
||||
# 更新现有 settings 对象的属性,而不是新建实例
|
||||
for key, value in ConfigSettings().model_dump().items():
|
||||
setattr(settings, key, value)
|
||||
setattr(settings, key, value)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import datetime
|
||||
import time
|
||||
import re # For potential status code parsing from generic errors
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
import openai
|
||||
from openai import APIStatusError # Import specific error type
|
||||
from openai import APIStatusError
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_embeddings_logger
|
||||
from app.database.services import add_error_log, add_request_log # Import DB logging functions
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
|
||||
logger = get_embeddings_logger()
|
||||
|
||||
@@ -26,7 +26,6 @@ class EmbeddingService:
|
||||
status_code = None
|
||||
response = None
|
||||
error_log_msg = ""
|
||||
# Prepare request message for logging (truncate if list or long string)
|
||||
if isinstance(input_text, list):
|
||||
request_msg_log = {"input_truncated": [str(item)[:100] + "..." if len(str(item)) > 100 else str(item) for item in input_text[:5]]}
|
||||
if len(input_text) > 5:
|
||||
@@ -39,39 +38,36 @@ class EmbeddingService:
|
||||
client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL)
|
||||
response = client.embeddings.create(input=input_text, model=model)
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 OK on success
|
||||
status_code = 200
|
||||
return response
|
||||
except APIStatusError as e:
|
||||
is_success = False
|
||||
status_code = e.status_code
|
||||
error_log_msg = f"OpenAI API error: {e}"
|
||||
logger.error(f"Error creating embedding (APIStatusError): {error_log_msg}")
|
||||
raise e # Re-raise the specific error
|
||||
raise e
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Generic error: {e}"
|
||||
logger.error(f"Error creating embedding (Exception): {error_log_msg}")
|
||||
# Try to parse status code from generic error (less reliable)
|
||||
match = re.search(r"status code (\d+)", str(e))
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default if parsing fails
|
||||
raise e # Re-raise the generic error
|
||||
status_code = 500
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
if not is_success:
|
||||
# Log error to database if it failed
|
||||
await add_error_log(
|
||||
gemini_key=api_key, # Using gemini_key parameter name for consistency
|
||||
model_name=model,
|
||||
error_type="openai-embedding",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request_msg_log
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-embedding",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request_msg_log
|
||||
)
|
||||
# Log request outcome to database regardless of success/failure
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
|
||||
178
app/service/error_log/error_log_service.py
Normal file
178
app/service/error_log/error_log_service.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.connection import database
|
||||
from app.database.models import ErrorLog
|
||||
from app.log.logger import get_error_log_logger
|
||||
|
||||
logger = get_error_log_logger()
|
||||
|
||||
|
||||
async def delete_old_error_logs():
|
||||
"""
|
||||
Deletes error logs older than a specified number of days,
|
||||
based on the AUTO_DELETE_ERROR_LOGS_ENABLED and AUTO_DELETE_ERROR_LOGS_DAYS settings.
|
||||
"""
|
||||
if not settings.AUTO_DELETE_ERROR_LOGS_ENABLED:
|
||||
logger.info("Auto-deletion of error logs is disabled. Skipping.")
|
||||
return
|
||||
|
||||
days_to_keep = settings.AUTO_DELETE_ERROR_LOGS_DAYS
|
||||
if not isinstance(days_to_keep, int) or days_to_keep <= 0:
|
||||
logger.error(
|
||||
f"Invalid AUTO_DELETE_ERROR_LOGS_DAYS value: {days_to_keep}. Must be a positive integer. Skipping deletion."
|
||||
)
|
||||
return
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
logger.info(
|
||||
f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."
|
||||
)
|
||||
|
||||
try:
|
||||
if not database.is_connected:
|
||||
await database.connect()
|
||||
logger.info("Database connection established for deleting error logs.")
|
||||
|
||||
# First, count how many logs will be deleted (optional, for logging)
|
||||
count_query = select(func.count(ErrorLog.id)).where(
|
||||
ErrorLog.request_time < cutoff_date
|
||||
)
|
||||
num_logs_to_delete = await database.fetch_val(count_query)
|
||||
|
||||
if num_logs_to_delete == 0:
|
||||
logger.info(
|
||||
"No error logs found older than the specified period. No deletion needed."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Found {num_logs_to_delete} error logs to delete.")
|
||||
|
||||
# Perform the deletion
|
||||
query = delete(ErrorLog).where(ErrorLog.request_time < cutoff_date)
|
||||
await database.execute(query)
|
||||
logger.info(
|
||||
f"Successfully deleted {num_logs_to_delete} error logs older than {days_to_keep} days."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during automatic deletion of error logs: {e}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
async def process_get_error_logs(
|
||||
limit: int,
|
||||
offset: int,
|
||||
key_search: Optional[str],
|
||||
error_search: Optional[str],
|
||||
error_code_search: Optional[str],
|
||||
start_date: Optional[datetime],
|
||||
end_date: Optional[datetime],
|
||||
sort_by: str,
|
||||
sort_order: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理错误日志的检索,支持分页和过滤。
|
||||
"""
|
||||
try:
|
||||
logs_data = await db_services.get_error_logs(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
total_count = await db_services.get_error_logs_count(
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return {"logs": logs_data, "total": total_count}
|
||||
except Exception as e:
|
||||
logger.error(f"Service error in process_get_error_logs: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def process_get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
处理特定错误日志详细信息的检索。
|
||||
如果未找到,则返回 None。
|
||||
"""
|
||||
try:
|
||||
log_details = await db_services.get_error_log_details(log_id=log_id)
|
||||
return log_details
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_get_error_log_details for ID {log_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def process_delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
||||
"""
|
||||
按 ID 批量删除错误日志。
|
||||
返回尝试删除的日志数量。
|
||||
"""
|
||||
if not log_ids:
|
||||
return 0
|
||||
try:
|
||||
deleted_count = await db_services.delete_error_logs_by_ids(log_ids)
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_delete_error_logs_by_ids for IDs {log_ids}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def process_delete_error_log_by_id(log_id: int) -> bool:
|
||||
"""
|
||||
按 ID 删除单个错误日志。
|
||||
如果删除成功(或找到日志并尝试删除),则返回 True,否则返回 False。
|
||||
"""
|
||||
try:
|
||||
success = await db_services.delete_error_log_by_id(log_id)
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_delete_error_log_by_id for ID {log_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def process_delete_all_error_logs() -> int:
|
||||
"""
|
||||
处理删除所有错误日志的请求。
|
||||
返回删除的日志数量。
|
||||
"""
|
||||
try:
|
||||
if not database.is_connected:
|
||||
await database.connect()
|
||||
logger.info("Database connection established for deleting all error logs.")
|
||||
|
||||
deleted_count = await db_services.delete_all_error_logs()
|
||||
logger.info(
|
||||
f"Successfully processed request to delete all error logs. Count: {deleted_count}"
|
||||
)
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_delete_all_error_logs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
1
app/service/files/__init__.py
Normal file
1
app/service/files/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Intentionally empty __init__.py file
|
||||
247
app/service/files/file_upload_handler.py
Normal file
247
app/service/files/file_upload_handler.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
文件上传处理器
|
||||
处理 Google 的可恢复上传协议
|
||||
"""
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import Request, Response, HTTPException
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.log.logger import get_files_logger
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
|
||||
class FileUploadHandler:
|
||||
"""处理文件分块上传"""
|
||||
|
||||
def __init__(self):
|
||||
self.chunk_size = 8 * 1024 * 1024 # 8MB
|
||||
|
||||
async def handle_upload_chunk(
|
||||
self,
|
||||
upload_url: str,
|
||||
request: Request,
|
||||
files_service=None # 添加 files_service 參數
|
||||
) -> Response:
|
||||
"""
|
||||
处理上传分块
|
||||
|
||||
Args:
|
||||
upload_url: 上传 URL
|
||||
request: FastAPI 请求对象
|
||||
files_service: 文件服務實例
|
||||
|
||||
Returns:
|
||||
Response: 响应对象
|
||||
"""
|
||||
try:
|
||||
# 获取请求头
|
||||
headers = {}
|
||||
|
||||
# 复制必要的上传头
|
||||
upload_headers = [
|
||||
"x-goog-upload-command",
|
||||
"x-goog-upload-offset",
|
||||
"content-type",
|
||||
"content-length"
|
||||
]
|
||||
|
||||
for header in upload_headers:
|
||||
if header in request.headers:
|
||||
# 转换为正确的格式
|
||||
key = "-".join(word.capitalize() for word in header.split("-"))
|
||||
headers[key] = request.headers[header]
|
||||
|
||||
# 读取请求体
|
||||
body = await request.body()
|
||||
|
||||
# 检查是否是最后一块
|
||||
is_final = "finalize" in headers.get("X-Goog-Upload-Command", "")
|
||||
logger.debug(f"Upload command: {headers.get('X-Goog-Upload-Command', '')}, is_final: {is_final}")
|
||||
|
||||
# 转发到真实的上传 URL
|
||||
async with AsyncClient() as client:
|
||||
response = await client.post(
|
||||
upload_url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
timeout=300.0 # 5分钟超时
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201, 308]:
|
||||
logger.error(f"Upload chunk failed: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Upload failed")
|
||||
|
||||
# 如果是最后一块,更新文件状态
|
||||
if is_final and response.status_code in [200, 201]:
|
||||
logger.debug(f"Upload finalized with status {response.status_code}")
|
||||
try:
|
||||
# 解析響應獲取文件信息
|
||||
response_data = response.json()
|
||||
logger.debug(f"Upload complete response data: {response_data}")
|
||||
file_data = response_data.get("file", {})
|
||||
|
||||
# 獲取真實的文件名
|
||||
real_file_name = file_data.get("name")
|
||||
logger.debug(f"Upload response: {response_data}")
|
||||
if real_file_name and files_service:
|
||||
logger.info(f"Upload completed, file name: {real_file_name}")
|
||||
|
||||
# 從會話中獲取信息
|
||||
session_info = await files_service.get_upload_session(upload_url)
|
||||
logger.debug(f"Retrieved session info for {upload_url}: {session_info}")
|
||||
if session_info:
|
||||
# 創建文件記錄
|
||||
now = datetime.now(timezone.utc)
|
||||
expiration_time = now + timedelta(hours=48)
|
||||
|
||||
# 處理過期時間格式(Google 可能返回納秒級精度)
|
||||
expiration_time_str = file_data.get("expirationTime", expiration_time.isoformat() + "Z")
|
||||
# 處理納秒格式:2025-07-11T02:02:52.531916141Z -> 2025-07-11T02:02:52.531916Z
|
||||
if expiration_time_str.endswith("Z"):
|
||||
# 移除 Z
|
||||
expiration_time_str = expiration_time_str[:-1]
|
||||
# 如果有納秒(超過6位小數),截斷到微秒
|
||||
if "." in expiration_time_str:
|
||||
date_part, frac_part = expiration_time_str.rsplit(".", 1)
|
||||
if len(frac_part) > 6:
|
||||
frac_part = frac_part[:6]
|
||||
expiration_time_str = f"{date_part}.{frac_part}"
|
||||
# 添加時區
|
||||
expiration_time_str += "+00:00"
|
||||
|
||||
# 獲取文件狀態(Google 可能返回 PROCESSING)
|
||||
file_state = file_data.get("state", "PROCESSING")
|
||||
logger.debug(f"File state from Google: {file_state}")
|
||||
|
||||
# 將字符串狀態轉換為枚舉
|
||||
if file_state == "ACTIVE":
|
||||
state_enum = FileState.ACTIVE
|
||||
elif file_state == "PROCESSING":
|
||||
state_enum = FileState.PROCESSING
|
||||
elif file_state == "FAILED":
|
||||
state_enum = FileState.FAILED
|
||||
else:
|
||||
logger.warning(f"Unknown file state: {file_state}, defaulting to PROCESSING")
|
||||
state_enum = FileState.PROCESSING
|
||||
|
||||
await db_services.create_file_record(
|
||||
name=real_file_name,
|
||||
mime_type=file_data.get("mimeType", session_info["mime_type"]),
|
||||
size_bytes=int(file_data.get("sizeBytes", session_info["size_bytes"])),
|
||||
api_key=session_info["api_key"],
|
||||
uri=file_data.get("uri", f"{settings.BASE_URL}/{real_file_name}"),
|
||||
create_time=now,
|
||||
update_time=now,
|
||||
expiration_time=datetime.fromisoformat(expiration_time_str),
|
||||
state=state_enum,
|
||||
display_name=file_data.get("displayName", session_info.get("display_name", "")),
|
||||
sha256_hash=file_data.get("sha256Hash"),
|
||||
user_token=session_info["user_token"]
|
||||
)
|
||||
logger.info(f"Created file record: name={real_file_name}, api_key={session_info['api_key'][:8]}...{session_info['api_key'][-4:]}")
|
||||
else:
|
||||
logger.warning(f"No upload session found for URL: {upload_url}")
|
||||
else:
|
||||
logger.warning(f"Missing real_file_name or files_service: real_file_name={real_file_name}, files_service={files_service}")
|
||||
|
||||
# 返回完整的文件信息
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file record: {str(e)}", exc_info=True)
|
||||
else:
|
||||
logger.debug(f"Upload chunk processed: is_final={is_final}, status={response.status_code}")
|
||||
|
||||
# 返回响应
|
||||
response_headers = dict(response.headers)
|
||||
|
||||
# 确保包含必要的头
|
||||
if response.status_code == 308: # Resume Incomplete
|
||||
if "x-goog-upload-status" not in response_headers:
|
||||
response_headers["x-goog-upload-status"] = "active"
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle upload chunk: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def proxy_upload_request(
|
||||
self,
|
||||
request: Request,
|
||||
upload_url: str,
|
||||
files_service=None
|
||||
) -> Response:
|
||||
"""
|
||||
代理上传请求
|
||||
|
||||
Args:
|
||||
request: FastAPI 请求对象
|
||||
upload_url: 目标上传 URL
|
||||
files_service: 文件服務實例
|
||||
|
||||
Returns:
|
||||
Response: 代理响应
|
||||
"""
|
||||
logger.debug(f"Proxy upload request: {request.method}, {upload_url}")
|
||||
try:
|
||||
# 如果是 GET 请求,返回上传状态
|
||||
if request.method == "GET":
|
||||
return await self._get_upload_status(upload_url)
|
||||
|
||||
# 处理 POST/PUT 请求
|
||||
return await self.handle_upload_chunk(upload_url, request, files_service)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to proxy upload request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def _get_upload_status(self, upload_url: str) -> Response:
|
||||
"""
|
||||
获取上传状态
|
||||
|
||||
Args:
|
||||
upload_url: 上传 URL
|
||||
|
||||
Returns:
|
||||
Response: 状态响应
|
||||
"""
|
||||
try:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(upload_url)
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get upload status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
|
||||
# 单例实例
|
||||
_upload_handler_instance: Optional[FileUploadHandler] = None
|
||||
|
||||
|
||||
def get_upload_handler() -> FileUploadHandler:
|
||||
"""获取上传处理器单例实例"""
|
||||
global _upload_handler_instance
|
||||
if _upload_handler_instance is None:
|
||||
_upload_handler_instance = FileUploadHandler()
|
||||
return _upload_handler_instance
|
||||
498
app/service/files/files_service.py
Normal file
498
app/service/files/files_service.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
文件管理服务
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from httpx import AsyncClient
|
||||
import asyncio
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.domain.file_models import FileMetadata, ListFilesResponse
|
||||
from fastapi import HTTPException
|
||||
from app.log.logger import get_files_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
# 全局上傳會話存儲
|
||||
_upload_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
_upload_sessions_lock = asyncio.Lock()
|
||||
|
||||
|
||||
class FilesService:
|
||||
"""文件管理服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_client = GeminiApiClient(base_url=settings.BASE_URL)
|
||||
self.key_manager = None
|
||||
|
||||
async def _get_key_manager(self):
|
||||
"""获取 KeyManager 实例"""
|
||||
if not self.key_manager:
|
||||
self.key_manager = await get_key_manager_instance(
|
||||
settings.API_KEYS,
|
||||
settings.VERTEX_API_KEYS
|
||||
)
|
||||
return self.key_manager
|
||||
|
||||
async def initialize_upload(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
body: Optional[bytes],
|
||||
user_token: str,
|
||||
request_host: str = None # 添加請求主機參數
|
||||
) -> Tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
初始化文件上传
|
||||
|
||||
Args:
|
||||
headers: 请求头
|
||||
body: 请求体
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], Dict[str, str]]: (响应体, 响应头)
|
||||
"""
|
||||
try:
|
||||
# 获取可用的 API key
|
||||
key_manager = await self._get_key_manager()
|
||||
api_key = await key_manager.get_next_key()
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No available API keys")
|
||||
|
||||
# 转发请求到真实的 Gemini API
|
||||
async with AsyncClient() as client:
|
||||
# 准备请求头
|
||||
forward_headers = {
|
||||
"X-Goog-Upload-Protocol": headers.get("x-goog-upload-protocol", "resumable"),
|
||||
"X-Goog-Upload-Command": headers.get("x-goog-upload-command", "start"),
|
||||
"Content-Type": headers.get("content-type", "application/json"),
|
||||
}
|
||||
|
||||
# 添加其他必要的头
|
||||
if "x-goog-upload-header-content-length" in headers:
|
||||
forward_headers["X-Goog-Upload-Header-Content-Length"] = headers["x-goog-upload-header-content-length"]
|
||||
if "x-goog-upload-header-content-type" in headers:
|
||||
forward_headers["X-Goog-Upload-Header-Content-Type"] = headers["x-goog-upload-header-content-type"]
|
||||
|
||||
# 发送请求
|
||||
response = await client.post(
|
||||
"https://generativelanguage.googleapis.com/upload/v1beta/files",
|
||||
headers=forward_headers,
|
||||
content=body,
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Upload initialization failed: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Upload initialization failed")
|
||||
|
||||
# 获取上传 URL
|
||||
upload_url = response.headers.get("x-goog-upload-url")
|
||||
if not upload_url:
|
||||
raise HTTPException(status_code=500, detail="No upload URL in response")
|
||||
|
||||
logger.info(f"Original upload URL from Google: {upload_url}")
|
||||
|
||||
|
||||
# 儲存上傳資訊到 headers 中,供後續使用
|
||||
# 不在這裡創建數據庫記錄,等到上傳完成後再創建
|
||||
logger.info(f"Upload initialized with API key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
# 解析响应 - 初始化响应可能是空的
|
||||
response_data = {}
|
||||
|
||||
# 從請求體中解析文件信息(如果有)
|
||||
display_name = ""
|
||||
if body:
|
||||
try:
|
||||
request_data = json.loads(body)
|
||||
display_name = request_data.get("displayName", "")
|
||||
except Exception:
|
||||
pass
|
||||
# 從 upload URL 中提取 upload_id
|
||||
import urllib.parse
|
||||
parsed_url = urllib.parse.urlparse(upload_url)
|
||||
query_params = urllib.parse.parse_qs(parsed_url.query)
|
||||
upload_id = query_params.get('upload_id', [None])[0]
|
||||
|
||||
if upload_id:
|
||||
# 儲存上傳會話信息,使用 upload_id 作為 key
|
||||
async with _upload_sessions_lock:
|
||||
_upload_sessions[upload_id] = {
|
||||
"api_key": api_key,
|
||||
"user_token": user_token,
|
||||
"display_name": display_name,
|
||||
"mime_type": headers.get("x-goog-upload-header-content-type", "application/octet-stream"),
|
||||
"size_bytes": int(headers.get("x-goog-upload-header-content-length", "0")),
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"upload_url": upload_url
|
||||
}
|
||||
logger.info(f"Stored upload session for upload_id={upload_id}: api_key={api_key[:8]}...{api_key[-4:]}")
|
||||
logger.debug(f"Total active sessions: {len(_upload_sessions)}")
|
||||
else:
|
||||
logger.warning(f"No upload_id found in upload URL: {upload_url}")
|
||||
|
||||
# 定期清理過期的會話(超過1小時)
|
||||
asyncio.create_task(self._cleanup_expired_sessions())
|
||||
|
||||
# 替換 Google 的 URL 為我們的代理 URL
|
||||
proxy_upload_url = upload_url
|
||||
if request_host:
|
||||
# 原始: https://generativelanguage.googleapis.com/upload/v1beta/files?key=AIzaSyDc...&upload_id=xxx&upload_protocol=resumable
|
||||
# 替換為: http://request-host/upload/v1beta/files?key=sk-123456&upload_id=xxx&upload_protocol=resumable
|
||||
|
||||
# 先替換域名
|
||||
proxy_upload_url = upload_url.replace(
|
||||
"https://generativelanguage.googleapis.com",
|
||||
request_host.rstrip('/')
|
||||
)
|
||||
|
||||
# 再替換 key 參數
|
||||
import re
|
||||
# 匹配 key=xxx 參數
|
||||
key_pattern = r'(\?|&)key=([^&]+)'
|
||||
match = re.search(key_pattern, proxy_upload_url)
|
||||
if match:
|
||||
# 替換為我們的 token
|
||||
proxy_upload_url = proxy_upload_url.replace(
|
||||
f"{match.group(1)}key={match.group(2)}",
|
||||
f"{match.group(1)}key={user_token}"
|
||||
)
|
||||
|
||||
logger.info(f"Replaced upload URL: {upload_url} -> {proxy_upload_url}")
|
||||
|
||||
return response_data, {
|
||||
"X-Goog-Upload-URL": proxy_upload_url,
|
||||
"X-Goog-Upload-Status": "active"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize upload: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def _cleanup_expired_sessions(self):
|
||||
"""清理過期的上傳會話"""
|
||||
try:
|
||||
async with _upload_sessions_lock:
|
||||
now = datetime.now(timezone.utc)
|
||||
expired_keys = []
|
||||
for key, session in _upload_sessions.items():
|
||||
if now - session["created_at"] > timedelta(hours=1):
|
||||
expired_keys.append(key)
|
||||
|
||||
for key in expired_keys:
|
||||
del _upload_sessions[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"Cleaned up {len(expired_keys)} expired upload sessions")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up upload sessions: {str(e)}")
|
||||
|
||||
async def get_upload_session(self, key: str) -> Optional[Dict[str, Any]]:
|
||||
"""獲取上傳會話信息(支持 upload_id 或完整 URL)"""
|
||||
async with _upload_sessions_lock:
|
||||
# 先嘗試直接查找
|
||||
session = _upload_sessions.get(key)
|
||||
if session:
|
||||
logger.debug(f"Found session by direct key {key}")
|
||||
return session
|
||||
|
||||
# 如果是 URL,嘗試提取 upload_id
|
||||
if key.startswith("http"):
|
||||
import urllib.parse
|
||||
parsed_url = urllib.parse.urlparse(key)
|
||||
query_params = urllib.parse.parse_qs(parsed_url.query)
|
||||
upload_id = query_params.get('upload_id', [None])[0]
|
||||
if upload_id:
|
||||
session = _upload_sessions.get(upload_id)
|
||||
if session:
|
||||
logger.debug(f"Found session by upload_id {upload_id} from URL")
|
||||
return session
|
||||
|
||||
logger.debug(f"No session found for key: {key}")
|
||||
return None
|
||||
|
||||
async def get_file(self, file_name: str, user_token: str) -> FileMetadata:
|
||||
"""
|
||||
获取文件信息
|
||||
|
||||
Args:
|
||||
file_name: 文件名称 (格式: files/{file_id})
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
FileMetadata: 文件元数据
|
||||
"""
|
||||
try:
|
||||
# 查询文件记录
|
||||
file_record = await db_services.get_file_record_by_name(file_name)
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# 检查是否过期
|
||||
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
|
||||
# 如果是 naive datetime,假设为 UTC
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
|
||||
if expiration_time <= datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=404, detail="File has expired")
|
||||
|
||||
# 使用原始 API key 获取文件信息
|
||||
api_key = file_record["api_key"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to get file: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Failed to get file")
|
||||
|
||||
file_data = response.json()
|
||||
|
||||
# 檢查並更新文件狀態
|
||||
google_state = file_data.get("state", "PROCESSING")
|
||||
if google_state != file_record.get("state", "").value if file_record.get("state") else None:
|
||||
logger.info(f"File state changed from {file_record.get('state')} to {google_state}")
|
||||
# 更新數據庫中的狀態
|
||||
if google_state == "ACTIVE":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.ACTIVE,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
elif google_state == "FAILED":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.FAILED,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
return FileMetadata(
|
||||
name=file_data["name"],
|
||||
displayName=file_data.get("displayName"),
|
||||
mimeType=file_data["mimeType"],
|
||||
sizeBytes=str(file_data["sizeBytes"]),
|
||||
createTime=file_data["createTime"],
|
||||
updateTime=file_data["updateTime"],
|
||||
expirationTime=file_data["expirationTime"],
|
||||
sha256Hash=file_data.get("sha256Hash"),
|
||||
uri=file_data["uri"],
|
||||
state=google_state
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file {file_name}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
page_size: int = 10,
|
||||
page_token: Optional[str] = None,
|
||||
user_token: Optional[str] = None
|
||||
) -> ListFilesResponse:
|
||||
"""
|
||||
列出文件
|
||||
|
||||
Args:
|
||||
page_size: 每页大小
|
||||
page_token: 分页标记
|
||||
user_token: 用户令牌(可选,如果提供则只返回该用户的文件)
|
||||
|
||||
Returns:
|
||||
ListFilesResponse: 文件列表响应
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"list_files called with page_size={page_size}, page_token={page_token}")
|
||||
|
||||
# 从数据库获取文件列表
|
||||
files, next_page_token = await db_services.list_file_records(
|
||||
user_token=user_token,
|
||||
page_size=page_size,
|
||||
page_token=page_token
|
||||
)
|
||||
|
||||
logger.debug(f"Database returned {len(files)} files, next_page_token={next_page_token}")
|
||||
|
||||
# 转换为响应格式
|
||||
file_list = []
|
||||
for file_record in files:
|
||||
file_list.append(FileMetadata(
|
||||
name=file_record["name"],
|
||||
displayName=file_record.get("display_name"),
|
||||
mimeType=file_record["mime_type"],
|
||||
sizeBytes=str(file_record["size_bytes"]),
|
||||
createTime=file_record["create_time"].isoformat() + "Z",
|
||||
updateTime=file_record["update_time"].isoformat() + "Z",
|
||||
expirationTime=file_record["expiration_time"].isoformat() + "Z",
|
||||
sha256Hash=file_record.get("sha256_hash"),
|
||||
uri=file_record["uri"],
|
||||
state=file_record["state"].value if file_record.get("state") else "ACTIVE"
|
||||
))
|
||||
|
||||
response = ListFilesResponse(
|
||||
files=file_list,
|
||||
nextPageToken=next_page_token
|
||||
)
|
||||
|
||||
logger.debug(f"Returning response with {len(response.files)} files, nextPageToken={response.nextPageToken}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list files: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def delete_file(self, file_name: str, user_token: str) -> bool:
|
||||
"""
|
||||
删除文件
|
||||
|
||||
Args:
|
||||
file_name: 文件名称
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
# 查询文件记录
|
||||
file_record = await db_services.get_file_record_by_name(file_name)
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# 使用原始 API key 删除文件
|
||||
api_key = file_record["api_key"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.delete(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 204]:
|
||||
logger.error(f"Failed to delete file: {response.status_code} - {response.text}")
|
||||
# 如果 API 删除失败,但文件已过期,仍然删除数据库记录
|
||||
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
|
||||
if expiration_time <= datetime.now(timezone.utc):
|
||||
await db_services.delete_file_record(file_name)
|
||||
return True
|
||||
raise HTTPException(status_code=response.status_code, detail="Failed to delete file")
|
||||
|
||||
# 删除数据库记录
|
||||
await db_services.delete_file_record(file_name)
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {file_name}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def check_file_state(self, file_name: str, api_key: str) -> str:
|
||||
"""
|
||||
檢查並更新文件狀態
|
||||
|
||||
Args:
|
||||
file_name: 文件名稱
|
||||
api_key: API密鑰
|
||||
|
||||
Returns:
|
||||
str: 當前狀態
|
||||
"""
|
||||
try:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to check file state: {response.status_code}")
|
||||
return "UNKNOWN"
|
||||
|
||||
file_data = response.json()
|
||||
google_state = file_data.get("state", "PROCESSING")
|
||||
|
||||
# 更新數據庫狀態
|
||||
if google_state == "ACTIVE":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.ACTIVE,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
elif google_state == "FAILED":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.FAILED,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
return google_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check file state: {str(e)}")
|
||||
return "UNKNOWN"
|
||||
|
||||
async def cleanup_expired_files(self) -> int:
|
||||
"""
|
||||
清理过期文件
|
||||
|
||||
Returns:
|
||||
int: 清理的文件数量
|
||||
"""
|
||||
try:
|
||||
# 获取过期文件
|
||||
expired_files = await db_services.delete_expired_file_records()
|
||||
|
||||
if not expired_files:
|
||||
return 0
|
||||
|
||||
# 尝试从 Gemini API 删除文件
|
||||
for file_record in expired_files:
|
||||
try:
|
||||
api_key = file_record["api_key"]
|
||||
file_name = file_record["name"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
await client.delete(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
except Exception as e:
|
||||
# 记录错误但继续处理其他文件
|
||||
logger.error(f"Failed to delete file {file_record['name']} from API: {str(e)}")
|
||||
|
||||
return len(expired_files)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup expired files: {str(e)}")
|
||||
return 0
|
||||
|
||||
|
||||
# 单例实例
|
||||
_files_service_instance: Optional[FilesService] = None
|
||||
|
||||
|
||||
async def get_files_service() -> FilesService:
|
||||
"""获取文件服务单例实例"""
|
||||
global _files_service_instance
|
||||
if _files_service_instance is None:
|
||||
_files_service_instance = FilesService()
|
||||
return _files_service_instance
|
||||
@@ -88,7 +88,6 @@ class ImageCreateService:
|
||||
aspect_ratio=self.aspect_ratio,
|
||||
safety_filter_level="BLOCK_LOW_AND_ABOVE",
|
||||
person_generation="ALLOW_ADULT",
|
||||
# language="auto"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -122,6 +121,7 @@ class ImageCreateService:
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -138,7 +138,7 @@ class ImageCreateService:
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"created": int(time.time()), # Current timestamp
|
||||
"created": int(time.time()),
|
||||
"data": images_data,
|
||||
}
|
||||
return response_data
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
from itertools import cycle
|
||||
from typing import Dict
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_key_manager_logger
|
||||
@@ -10,12 +9,19 @@ logger = get_key_manager_logger()
|
||||
|
||||
|
||||
class KeyManager:
|
||||
def __init__(self, api_keys: list):
|
||||
def __init__(self, api_keys: list, vertex_api_keys: list):
|
||||
self.api_keys = api_keys
|
||||
self.vertex_api_keys = vertex_api_keys
|
||||
self.key_cycle = cycle(api_keys)
|
||||
self.vertex_key_cycle = cycle(vertex_api_keys)
|
||||
self.key_cycle_lock = asyncio.Lock()
|
||||
self.vertex_key_cycle_lock = asyncio.Lock()
|
||||
self.failure_count_lock = asyncio.Lock()
|
||||
self.vertex_failure_count_lock = asyncio.Lock()
|
||||
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
|
||||
self.vertex_key_failure_counts: Dict[str, int] = {
|
||||
key: 0 for key in vertex_api_keys
|
||||
}
|
||||
self.MAX_FAILURES = settings.MAX_FAILURES
|
||||
self.paid_key = settings.PAID_KEY
|
||||
|
||||
@@ -27,17 +33,33 @@ class KeyManager:
|
||||
async with self.key_cycle_lock:
|
||||
return next(self.key_cycle)
|
||||
|
||||
async def get_next_vertex_key(self) -> str:
|
||||
"""获取下一个 Vertex Express API key"""
|
||||
async with self.vertex_key_cycle_lock:
|
||||
return next(self.vertex_key_cycle)
|
||||
|
||||
async def is_key_valid(self, key: str) -> bool:
|
||||
"""检查key是否有效"""
|
||||
async with self.failure_count_lock:
|
||||
return self.key_failure_counts[key] < self.MAX_FAILURES
|
||||
|
||||
async def is_vertex_key_valid(self, key: str) -> bool:
|
||||
"""检查 Vertex key 是否有效"""
|
||||
async with self.vertex_failure_count_lock:
|
||||
return self.vertex_key_failure_counts[key] < self.MAX_FAILURES
|
||||
|
||||
async def reset_failure_counts(self):
|
||||
"""重置所有key的失败计数"""
|
||||
async with self.failure_count_lock:
|
||||
for key in self.key_failure_counts:
|
||||
self.key_failure_counts[key] = 0
|
||||
|
||||
|
||||
async def reset_vertex_failure_counts(self):
|
||||
"""重置所有 Vertex key 的失败计数"""
|
||||
async with self.vertex_failure_count_lock:
|
||||
for key in self.vertex_key_failure_counts:
|
||||
self.vertex_key_failure_counts[key] = 0
|
||||
|
||||
async def reset_key_failure_count(self, key: str) -> bool:
|
||||
"""重置指定key的失败计数"""
|
||||
async with self.failure_count_lock:
|
||||
@@ -45,7 +67,21 @@ class KeyManager:
|
||||
self.key_failure_counts[key] = 0
|
||||
logger.info(f"Reset failure count for key: {key}")
|
||||
return True
|
||||
logger.warning(f"Attempt to reset failure count for non-existent key: {key}")
|
||||
logger.warning(
|
||||
f"Attempt to reset failure count for non-existent key: {key}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def reset_vertex_key_failure_count(self, key: str) -> bool:
|
||||
"""重置指定 Vertex key 的失败计数"""
|
||||
async with self.vertex_failure_count_lock:
|
||||
if key in self.vertex_key_failure_counts:
|
||||
self.vertex_key_failure_counts[key] = 0
|
||||
logger.info(f"Reset failure count for Vertex key: {key}")
|
||||
return True
|
||||
logger.warning(
|
||||
f"Attempt to reset failure count for non-existent Vertex key: {key}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def get_next_working_key(self) -> str:
|
||||
@@ -59,10 +95,22 @@ class KeyManager:
|
||||
|
||||
current_key = await self.get_next_key()
|
||||
if current_key == initial_key:
|
||||
# await self.reset_failure_counts() 取消重置
|
||||
return current_key
|
||||
|
||||
async def handle_api_failure(self, api_key: str,retries: int) -> str:
|
||||
async def get_next_working_vertex_key(self) -> str:
|
||||
"""获取下一可用的 Vertex Express API key"""
|
||||
initial_key = await self.get_next_vertex_key()
|
||||
current_key = initial_key
|
||||
|
||||
while True:
|
||||
if await self.is_vertex_key_valid(current_key):
|
||||
return current_key
|
||||
|
||||
current_key = await self.get_next_vertex_key()
|
||||
if current_key == initial_key:
|
||||
return current_key
|
||||
|
||||
async def handle_api_failure(self, api_key: str, retries: int) -> str:
|
||||
"""处理API调用失败"""
|
||||
async with self.failure_count_lock:
|
||||
self.key_failure_counts[api_key] += 1
|
||||
@@ -72,13 +120,26 @@ class KeyManager:
|
||||
)
|
||||
if retries < settings.MAX_RETRIES:
|
||||
return await self.get_next_working_key()
|
||||
else:
|
||||
else:
|
||||
return ""
|
||||
|
||||
async def handle_vertex_api_failure(self, api_key: str, retries: int) -> str:
|
||||
"""处理 Vertex Express API 调用失败"""
|
||||
async with self.vertex_failure_count_lock:
|
||||
self.vertex_key_failure_counts[api_key] += 1
|
||||
if self.vertex_key_failure_counts[api_key] >= self.MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"Vertex Express API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
)
|
||||
|
||||
def get_fail_count(self, key: str) -> int:
|
||||
"""获取指定密钥的失败次数"""
|
||||
return self.key_failure_counts.get(key, 0)
|
||||
|
||||
def get_vertex_fail_count(self, key: str) -> int:
|
||||
"""获取指定 Vertex 密钥的失败次数"""
|
||||
return self.vertex_key_failure_counts.get(key, 0)
|
||||
|
||||
async def get_keys_by_status(self) -> dict:
|
||||
"""获取分类后的API key列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
@@ -94,40 +155,309 @@ class KeyManager:
|
||||
|
||||
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
||||
|
||||
async def get_vertex_keys_by_status(self) -> dict:
|
||||
"""获取分类后的 Vertex Express API key 列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
invalid_keys = {}
|
||||
|
||||
async with self.vertex_failure_count_lock:
|
||||
for key in self.vertex_api_keys:
|
||||
fail_count = self.vertex_key_failure_counts[key]
|
||||
if fail_count < self.MAX_FAILURES:
|
||||
valid_keys[key] = fail_count
|
||||
else:
|
||||
invalid_keys[key] = fail_count
|
||||
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
||||
|
||||
async def get_first_valid_key(self) -> str:
|
||||
"""获取第一个有效的API key"""
|
||||
async with self.failure_count_lock:
|
||||
for key in self.key_failure_counts:
|
||||
if self.key_failure_counts[key] < self.MAX_FAILURES:
|
||||
return key
|
||||
if self.api_keys:
|
||||
return self.api_keys[0]
|
||||
if not self.api_keys:
|
||||
logger.warning("API key list is empty, cannot get first valid key.")
|
||||
return ""
|
||||
return self.api_keys[0]
|
||||
|
||||
|
||||
_singleton_instance = None
|
||||
_singleton_lock = asyncio.Lock()
|
||||
_preserved_failure_counts: Union[Dict[str, int], None] = None
|
||||
_preserved_vertex_failure_counts: Union[Dict[str, int], None] = None
|
||||
_preserved_old_api_keys_for_reset: Union[list, None] = None
|
||||
_preserved_vertex_old_api_keys_for_reset: Union[list, None] = None
|
||||
_preserved_next_key_in_cycle: Union[str, None] = None
|
||||
_preserved_vertex_next_key_in_cycle: Union[str, None] = None
|
||||
|
||||
|
||||
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
|
||||
async def get_key_manager_instance(
|
||||
api_keys: list = None, vertex_api_keys: list = None
|
||||
) -> KeyManager:
|
||||
"""
|
||||
获取 KeyManager 单例实例。
|
||||
|
||||
如果尚未创建实例,将使用提供的 api_keys 初始化 KeyManager。
|
||||
如果尚未创建实例,将使用提供的 api_keys,vertex_api_keys 初始化 KeyManager。
|
||||
如果已创建实例,则忽略 api_keys 参数,返回现有单例。
|
||||
如果在重置后调用,会尝试恢复之前的状态(失败计数、循环位置)。
|
||||
"""
|
||||
global _singleton_instance
|
||||
global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle
|
||||
|
||||
async with _singleton_lock:
|
||||
if _singleton_instance is None:
|
||||
if api_keys is None:
|
||||
raise ValueError("API keys are required to initialize the KeyManager")
|
||||
_singleton_instance = KeyManager(api_keys)
|
||||
logger.info("KeyManager instance created.")
|
||||
raise ValueError(
|
||||
"API keys are required to initialize or re-initialize the KeyManager instance."
|
||||
)
|
||||
if vertex_api_keys is None:
|
||||
raise ValueError(
|
||||
"Vertex Express API keys are required to initialize or re-initialize the KeyManager instance."
|
||||
)
|
||||
|
||||
if not api_keys:
|
||||
logger.warning(
|
||||
"Initializing KeyManager with an empty list of API keys."
|
||||
)
|
||||
if not vertex_api_keys:
|
||||
logger.warning(
|
||||
"Initializing KeyManager with an empty list of Vertex Express API keys."
|
||||
)
|
||||
|
||||
_singleton_instance = KeyManager(api_keys, vertex_api_keys)
|
||||
logger.info(
|
||||
f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex Express API keys."
|
||||
)
|
||||
|
||||
# 1. 恢复失败计数
|
||||
if _preserved_failure_counts:
|
||||
current_failure_counts = {
|
||||
key: 0 for key in _singleton_instance.api_keys
|
||||
}
|
||||
for key, count in _preserved_failure_counts.items():
|
||||
if key in current_failure_counts:
|
||||
current_failure_counts[key] = count
|
||||
_singleton_instance.key_failure_counts = current_failure_counts
|
||||
logger.info("Inherited failure counts for applicable keys.")
|
||||
_preserved_failure_counts = None
|
||||
|
||||
if _preserved_vertex_failure_counts:
|
||||
current_vertex_failure_counts = {
|
||||
key: 0 for key in _singleton_instance.vertex_api_keys
|
||||
}
|
||||
for key, count in _preserved_vertex_failure_counts.items():
|
||||
if key in current_vertex_failure_counts:
|
||||
current_vertex_failure_counts[key] = count
|
||||
_singleton_instance.vertex_key_failure_counts = (
|
||||
current_vertex_failure_counts
|
||||
)
|
||||
logger.info("Inherited failure counts for applicable Vertex keys.")
|
||||
_preserved_vertex_failure_counts = None
|
||||
|
||||
# 2. 调整 key_cycle 的起始点
|
||||
start_key_for_new_cycle = None
|
||||
if (
|
||||
_preserved_old_api_keys_for_reset
|
||||
and _preserved_next_key_in_cycle
|
||||
and _singleton_instance.api_keys
|
||||
):
|
||||
try:
|
||||
start_idx_in_old = _preserved_old_api_keys_for_reset.index(
|
||||
_preserved_next_key_in_cycle
|
||||
)
|
||||
|
||||
for i in range(len(_preserved_old_api_keys_for_reset)):
|
||||
current_old_key_idx = (start_idx_in_old + i) % len(
|
||||
_preserved_old_api_keys_for_reset
|
||||
)
|
||||
key_candidate = _preserved_old_api_keys_for_reset[
|
||||
current_old_key_idx
|
||||
]
|
||||
if key_candidate in _singleton_instance.api_keys:
|
||||
start_key_for_new_cycle = key_candidate
|
||||
break
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Preserved next key '{_preserved_next_key_in_cycle}' not found in preserved old API keys. "
|
||||
"New cycle will start from the beginning of the new list."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error determining start key for new cycle from preserved state: {e}. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
|
||||
if start_key_for_new_cycle and _singleton_instance.api_keys:
|
||||
try:
|
||||
target_idx = _singleton_instance.api_keys.index(
|
||||
start_key_for_new_cycle
|
||||
)
|
||||
for _ in range(target_idx):
|
||||
next(_singleton_instance.key_cycle)
|
||||
logger.info(
|
||||
f"Key cycle in new instance advanced. Next call to get_next_key() will yield: {start_key_for_new_cycle}"
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Determined start key '{start_key_for_new_cycle}' not found in new API keys during cycle advancement. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
except StopIteration:
|
||||
logger.error(
|
||||
"StopIteration while advancing key cycle, implies empty new API key list previously missed."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error advancing new key cycle: {e}. Cycle will start from beginning."
|
||||
)
|
||||
else:
|
||||
if _singleton_instance.api_keys:
|
||||
logger.info(
|
||||
"New key cycle will start from the beginning of the new API key list (no specific start key determined or needed)."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"New key cycle not applicable as the new API key list is empty."
|
||||
)
|
||||
|
||||
# 清理所有保存的状态
|
||||
_preserved_old_api_keys_for_reset = None
|
||||
_preserved_next_key_in_cycle = None
|
||||
|
||||
# 3. 调整 vertex_key_cycle 的起始点
|
||||
start_key_for_new_vertex_cycle = None
|
||||
if (
|
||||
_preserved_vertex_old_api_keys_for_reset
|
||||
and _preserved_vertex_next_key_in_cycle
|
||||
and _singleton_instance.vertex_api_keys
|
||||
):
|
||||
try:
|
||||
start_idx_in_old = _preserved_vertex_old_api_keys_for_reset.index(
|
||||
_preserved_vertex_next_key_in_cycle
|
||||
)
|
||||
|
||||
for i in range(len(_preserved_vertex_old_api_keys_for_reset)):
|
||||
current_old_key_idx = (start_idx_in_old + i) % len(
|
||||
_preserved_vertex_old_api_keys_for_reset
|
||||
)
|
||||
key_candidate = _preserved_vertex_old_api_keys_for_reset[
|
||||
current_old_key_idx
|
||||
]
|
||||
if key_candidate in _singleton_instance.vertex_api_keys:
|
||||
start_key_for_new_vertex_cycle = key_candidate
|
||||
break
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex Express API keys. "
|
||||
"New cycle will start from the beginning of the new list."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error determining start key for new Vertex key cycle from preserved state: {e}. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
|
||||
if start_key_for_new_vertex_cycle and _singleton_instance.vertex_api_keys:
|
||||
try:
|
||||
target_idx = _singleton_instance.vertex_api_keys.index(
|
||||
start_key_for_new_vertex_cycle
|
||||
)
|
||||
for _ in range(target_idx):
|
||||
next(_singleton_instance.vertex_key_cycle)
|
||||
logger.info(
|
||||
f"Vertex key cycle in new instance advanced. Next call to get_next_vertex_key() will yield: {start_key_for_new_vertex_cycle}"
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex Express API keys during cycle advancement. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
except StopIteration:
|
||||
logger.error(
|
||||
"StopIteration while advancing Vertex key cycle, implies empty new Vertex Express API key list previously missed."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error advancing new Vertex key cycle: {e}. Cycle will start from beginning."
|
||||
)
|
||||
else:
|
||||
if _singleton_instance.vertex_api_keys:
|
||||
logger.info(
|
||||
"New Vertex key cycle will start from the beginning of the new Vertex Express API key list (no specific start key determined or needed)."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"New Vertex key cycle not applicable as the new Vertex Express API key list is empty."
|
||||
)
|
||||
|
||||
# 清理所有保存的状态
|
||||
_preserved_vertex_old_api_keys_for_reset = None
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
|
||||
return _singleton_instance
|
||||
|
||||
|
||||
|
||||
async def reset_key_manager_instance():
|
||||
"""重置 KeyManager 单例实例"""
|
||||
global _singleton_instance
|
||||
"""
|
||||
重置 KeyManager 单例实例。
|
||||
将保存当前实例的状态(失败计数、旧 API keys、下一个 key 提示)
|
||||
以供下一次 get_key_manager_instance 调用时恢复。
|
||||
"""
|
||||
global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle
|
||||
async with _singleton_lock:
|
||||
if _singleton_instance:
|
||||
# 1. 保存失败计数
|
||||
_preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
|
||||
_preserved_vertex_failure_counts = (
|
||||
_singleton_instance.vertex_key_failure_counts.copy()
|
||||
)
|
||||
|
||||
# 2. 保存旧的 API keys 列表
|
||||
_preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
|
||||
_preserved_vertex_old_api_keys_for_reset = (
|
||||
_singleton_instance.vertex_api_keys.copy()
|
||||
)
|
||||
|
||||
# 3. 保存 key_cycle 的下一个 key 提示
|
||||
try:
|
||||
if _singleton_instance.api_keys:
|
||||
_preserved_next_key_in_cycle = (
|
||||
await _singleton_instance.get_next_key()
|
||||
)
|
||||
else:
|
||||
_preserved_next_key_in_cycle = None
|
||||
except StopIteration:
|
||||
logger.warning(
|
||||
"Could not preserve next key hint: key cycle was empty or exhausted in old instance."
|
||||
)
|
||||
_preserved_next_key_in_cycle = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error preserving next key hint during reset: {e}")
|
||||
_preserved_next_key_in_cycle = None
|
||||
|
||||
# 4. 保存 vertex_key_cycle 的下一个 key 提示
|
||||
try:
|
||||
if _singleton_instance.vertex_api_keys:
|
||||
_preserved_vertex_next_key_in_cycle = (
|
||||
await _singleton_instance.get_next_vertex_key()
|
||||
)
|
||||
else:
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
except StopIteration:
|
||||
logger.warning(
|
||||
"Could not preserve next key hint: Vertex key cycle was empty or exhausted in old instance."
|
||||
)
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error preserving next key hint during reset: {e}")
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
|
||||
_singleton_instance = None
|
||||
logger.info("KeyManager instance reset.")
|
||||
logger.info(
|
||||
"KeyManager instance has been reset. State (failure counts, old keys, next key hint) preserved for next instantiation."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"KeyManager instance was not set (or already reset), no reset action performed."
|
||||
)
|
||||
|
||||
@@ -1,50 +1,46 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_model_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
|
||||
logger = get_model_logger()
|
||||
|
||||
|
||||
class ModelService:
|
||||
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
url = f"{settings.BASE_URL}/models?key={api_key}"
|
||||
async def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
api_client = GeminiApiClient(base_url=settings.BASE_URL)
|
||||
gemini_models = await api_client.get_models(api_key)
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
gemini_models = response.json()
|
||||
|
||||
filtered_models_list = []
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
if model_id not in settings.FILTERED_MODELS:
|
||||
filtered_models_list.append(model)
|
||||
else:
|
||||
logger.debug(f"Filtered out model: {model_id}")
|
||||
|
||||
gemini_models["models"] = filtered_models_list
|
||||
return gemini_models
|
||||
else:
|
||||
logger.error(f"Error: {response.status_code}")
|
||||
logger.error(response.text)
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
if gemini_models is None:
|
||||
logger.error("从 API 客户端获取模型列表失败。")
|
||||
return None
|
||||
|
||||
def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
gemini_models = self.get_gemini_models(api_key)
|
||||
return self.convert_to_openai_models_format(gemini_models)
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
filtered_models_list = []
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
if model_id not in settings.FILTERED_MODELS:
|
||||
filtered_models_list.append(model)
|
||||
else:
|
||||
logger.debug(f"Filtered out model: {model_id}")
|
||||
|
||||
gemini_models["models"] = filtered_models_list
|
||||
return gemini_models
|
||||
except Exception as e:
|
||||
logger.error(f"处理模型列表时出错: {e}")
|
||||
return None
|
||||
|
||||
def convert_to_openai_models_format(
|
||||
async def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取 Gemini 模型并转换为 OpenAI 格式"""
|
||||
gemini_models = await self.get_gemini_models(api_key)
|
||||
if gemini_models is None:
|
||||
return None
|
||||
|
||||
return await self.convert_to_openai_models_format(gemini_models)
|
||||
|
||||
async def convert_to_openai_models_format(
|
||||
self, gemini_models: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
openai_format = {"object": "list", "data": [], "success": True}
|
||||
@@ -81,7 +77,7 @@ class ModelService:
|
||||
openai_format["data"].append(image_model)
|
||||
return openai_format
|
||||
|
||||
def check_model_support(self, model: str) -> bool:
|
||||
async def check_model_support(self, model: str) -> bool:
|
||||
if not model or not isinstance(model, str):
|
||||
return False
|
||||
|
||||
|
||||
190
app/service/openai_compatiable/openai_compatiable_service.py
Normal file
190
app/service/openai_compatiable/openai_compatiable_service.py
Normal file
@@ -0,0 +1,190 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import (
|
||||
add_error_log,
|
||||
add_request_log,
|
||||
)
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.service.client.api_client import OpenaiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
class OpenAICompatiableService:
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
self.key_manager = key_manager
|
||||
self.base_url = base_url
|
||||
self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
return await self.api_client.get_models(api_key)
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
request_dict = request.model_dump()
|
||||
# 移除值为null的
|
||||
request_dict = {k: v for k, v in request_dict.items() if v is not None}
|
||||
del request_dict["top_k"] # 删除top_k参数,目前不支持该参数
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, request_dict, api_key)
|
||||
return await self._handle_normal_completion(request.model, request_dict, api_key)
|
||||
|
||||
async def generate_images(
|
||||
self,
|
||||
request: ImageGenerationRequest,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成图片"""
|
||||
request_dict = request.model_dump()
|
||||
# 移除值为null的
|
||||
request_dict = {k: v for k, v in request_dict.items() if v is not None}
|
||||
api_key = settings.PAID_KEY
|
||||
return await self.api_client.generate_images(request_dict, api_key)
|
||||
|
||||
async def create_embeddings(
|
||||
self,
|
||||
input_text: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建嵌入"""
|
||||
return await self.api_client.create_embeddings(input_text, model, api_key)
|
||||
|
||||
async def _handle_normal_completion(
|
||||
self, model: str, request: dict, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理普通聊天完成"""
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
try:
|
||||
response = await self.api_client.generate_content(request, api_key)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-compatiable-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: dict, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑"""
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
is_success = False
|
||||
status_code = None
|
||||
final_api_key = api_key
|
||||
|
||||
while retries < max_retries:
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, current_attempt_key
|
||||
):
|
||||
if line.startswith("data:"):
|
||||
# print(line)
|
||||
yield line + "\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
is_success = True
|
||||
status_code = 200
|
||||
break
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="openai-compatiable-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
)
|
||||
|
||||
if self.key_manager:
|
||||
api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
if not is_success and retries >= max_retries:
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
50
app/service/request_log/request_log_service.py
Normal file
50
app/service/request_log/request_log_service.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Service for request log operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from app.database.connection import database
|
||||
from app.config.config import settings
|
||||
from app.database.models import RequestLog
|
||||
from app.log.logger import get_request_log_logger
|
||||
|
||||
logger = get_request_log_logger()
|
||||
|
||||
|
||||
async def delete_old_request_logs_task():
|
||||
"""
|
||||
定时删除旧的请求日志。
|
||||
"""
|
||||
if not settings.AUTO_DELETE_REQUEST_LOGS_ENABLED:
|
||||
logger.info(
|
||||
"Auto-delete for request logs is disabled by settings. Skipping task."
|
||||
)
|
||||
return
|
||||
|
||||
days_to_keep = settings.AUTO_DELETE_REQUEST_LOGS_DAYS
|
||||
logger.info(
|
||||
f"Starting scheduled task to delete old request logs older than {days_to_keep} days."
|
||||
)
|
||||
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
query = delete(RequestLog).where(RequestLog.request_time < cutoff_date)
|
||||
|
||||
if not database.is_connected:
|
||||
logger.info("Connecting to database for request log deletion.")
|
||||
await database.connect()
|
||||
|
||||
result = await database.execute(query)
|
||||
logger.info(
|
||||
f"Request logs older than {cutoff_date} potentially deleted. Rows affected: {result.rowcount if result else 'N/A'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled request log deletion: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
255
app/service/stats/stats_service.py
Normal file
255
app/service/stats/stats_service.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# app/service/stats_service.py
|
||||
|
||||
import datetime
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
|
||||
from app.database.connection import database
|
||||
from app.database.models import RequestLog
|
||||
from app.log.logger import get_stats_logger
|
||||
|
||||
logger = get_stats_logger()
|
||||
|
||||
|
||||
class StatsService:
|
||||
"""Service class for handling statistics related operations."""
|
||||
|
||||
async def get_calls_in_last_seconds(self, seconds: int) -> dict[str, int]:
|
||||
"""获取过去 N 秒内的调用次数 (总数、成功、失败)"""
|
||||
try:
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(seconds=seconds)
|
||||
query = select(
|
||||
func.count(RequestLog.id).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
and_(
|
||||
RequestLog.status_code >= 200,
|
||||
RequestLog.status_code < 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
).label("success"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
or_(
|
||||
RequestLog.status_code < 200,
|
||||
RequestLog.status_code >= 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
(RequestLog.status_code is None, 1),
|
||||
else_=0,
|
||||
)
|
||||
).label("failure"),
|
||||
).where(RequestLog.request_time >= cutoff_time)
|
||||
result = await database.fetch_one(query)
|
||||
if result:
|
||||
return {
|
||||
"total": result["total"] or 0,
|
||||
"success": result["success"] or 0,
|
||||
"failure": result["failure"] or 0,
|
||||
}
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get calls in last {seconds} seconds: {e}")
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
|
||||
async def get_calls_in_last_minutes(self, minutes: int) -> dict[str, int]:
|
||||
"""获取过去 N 分钟内的调用次数 (总数、成功、失败)"""
|
||||
return await self.get_calls_in_last_seconds(minutes * 60)
|
||||
|
||||
async def get_calls_in_last_hours(self, hours: int) -> dict[str, int]:
|
||||
"""获取过去 N 小时内的调用次数 (总数、成功、失败)"""
|
||||
return await self.get_calls_in_last_seconds(hours * 3600)
|
||||
|
||||
async def get_calls_in_current_month(self) -> dict[str, int]:
|
||||
"""获取当前自然月内的调用次数 (总数、成功、失败)"""
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
start_of_month = now.replace(
|
||||
day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
query = select(
|
||||
func.count(RequestLog.id).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
and_(
|
||||
RequestLog.status_code >= 200,
|
||||
RequestLog.status_code < 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
).label("success"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
or_(
|
||||
RequestLog.status_code < 200,
|
||||
RequestLog.status_code >= 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
(RequestLog.status_code is None, 1),
|
||||
else_=0,
|
||||
)
|
||||
).label("failure"),
|
||||
).where(RequestLog.request_time >= start_of_month)
|
||||
result = await database.fetch_one(query)
|
||||
if result:
|
||||
return {
|
||||
"total": result["total"] or 0,
|
||||
"success": result["success"] or 0,
|
||||
"failure": result["failure"] or 0,
|
||||
}
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get calls in current month: {e}")
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
|
||||
async def get_api_usage_stats(self) -> dict:
|
||||
"""获取所有需要的 API 使用统计数据 (总数、成功、失败)"""
|
||||
try:
|
||||
stats_1m = await self.get_calls_in_last_minutes(1)
|
||||
stats_1h = await self.get_calls_in_last_hours(1)
|
||||
stats_24h = await self.get_calls_in_last_hours(24)
|
||||
stats_month = await self.get_calls_in_current_month()
|
||||
|
||||
return {
|
||||
"calls_1m": stats_1m,
|
||||
"calls_1h": stats_1h,
|
||||
"calls_24h": stats_24h,
|
||||
"calls_month": stats_month,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get API usage stats: {e}")
|
||||
default_stat = {"total": 0, "success": 0, "failure": 0}
|
||||
return {
|
||||
"calls_1m": default_stat.copy(),
|
||||
"calls_1h": default_stat.copy(),
|
||||
"calls_24h": default_stat.copy(),
|
||||
"calls_month": default_stat.copy(),
|
||||
}
|
||||
|
||||
async def get_api_call_details(self, period: str) -> list[dict]:
|
||||
"""
|
||||
获取指定时间段内的 API 调用详情
|
||||
|
||||
Args:
|
||||
period: 时间段标识 ('1m', '1h', '24h')
|
||||
|
||||
Returns:
|
||||
包含调用详情的字典列表,每个字典包含 timestamp, key, model, status
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 period 无效
|
||||
"""
|
||||
now = datetime.datetime.now()
|
||||
if period == "1m":
|
||||
start_time = now - datetime.timedelta(minutes=1)
|
||||
elif period == "1h":
|
||||
start_time = now - datetime.timedelta(hours=1)
|
||||
elif period == "24h":
|
||||
start_time = now - datetime.timedelta(hours=24)
|
||||
else:
|
||||
raise ValueError(f"无效的时间段标识: {period}")
|
||||
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
RequestLog.request_time.label("timestamp"),
|
||||
RequestLog.api_key.label("key"),
|
||||
RequestLog.model_name.label("model"),
|
||||
RequestLog.status_code,
|
||||
)
|
||||
.where(RequestLog.request_time >= start_time)
|
||||
.order_by(RequestLog.request_time.desc())
|
||||
)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
details = []
|
||||
for row in results:
|
||||
status = "failure"
|
||||
if row["status_code"] is not None:
|
||||
status = "success" if 200 <= row["status_code"] < 300 else "failure"
|
||||
details.append(
|
||||
{
|
||||
"timestamp": row[
|
||||
"timestamp"
|
||||
].isoformat(),
|
||||
"key": row["key"],
|
||||
"model": row["model"],
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Retrieved {len(details)} API call details for period '{period}'"
|
||||
)
|
||||
return details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get API call details for period '{period}': {e}")
|
||||
raise
|
||||
|
||||
async def get_key_usage_details_last_24h(self, key: str) -> Union[dict, None]:
|
||||
"""
|
||||
获取指定 API 密钥在过去 24 小时内按模型统计的调用次数。
|
||||
|
||||
Args:
|
||||
key: 要查询的 API 密钥。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是模型名称,值是调用次数。
|
||||
如果查询出错或没有找到记录,可能返回 None 或空字典。
|
||||
Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}
|
||||
"""
|
||||
logger.info(
|
||||
f"Fetching usage details for key ending in ...{key[-4:]} for the last 24h."
|
||||
)
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=24)
|
||||
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
RequestLog.model_name, func.count(
|
||||
RequestLog.id).label("call_count")
|
||||
)
|
||||
.where(
|
||||
RequestLog.api_key == key,
|
||||
RequestLog.request_time >= cutoff_time,
|
||||
RequestLog.model_name.isnot(None),
|
||||
)
|
||||
.group_by(RequestLog.model_name)
|
||||
.order_by(func.count(RequestLog.id).desc())
|
||||
)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
if not results:
|
||||
logger.info(
|
||||
f"No usage details found for key ending in ...{key[-4:]} in the last 24h."
|
||||
)
|
||||
return {}
|
||||
|
||||
usage_details = {row["model_name"]: row["call_count"]
|
||||
for row in results}
|
||||
logger.info(
|
||||
f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}"
|
||||
)
|
||||
return usage_details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get key usage details for key ending in ...{key[-4:]}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
@@ -1,174 +0,0 @@
|
||||
# app/service/stats_service.py
|
||||
|
||||
import datetime
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from app.database.connection import database
|
||||
from app.database.models import RequestLog
|
||||
from app.log.logger import get_stats_logger
|
||||
|
||||
logger = get_stats_logger()
|
||||
|
||||
|
||||
class StatsService:
|
||||
"""Service class for handling statistics related operations."""
|
||||
|
||||
async def get_calls_in_last_seconds(self, seconds: int) -> int:
|
||||
"""获取过去 N 秒内的调用次数 (包括成功和失败)"""
|
||||
try:
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(seconds=seconds)
|
||||
query = select(func.count(RequestLog.id)).where(
|
||||
RequestLog.request_time >= cutoff_time
|
||||
)
|
||||
count_result = await database.fetch_one(query)
|
||||
return count_result[0] if count_result else 0
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get calls in last {seconds} seconds: {e}")
|
||||
return 0 # Return 0 on error
|
||||
|
||||
async def get_calls_in_last_minutes(self, minutes: int) -> int:
|
||||
"""获取过去 N 分钟内的调用次数 (包括成功和失败)"""
|
||||
return await self.get_calls_in_last_seconds(minutes * 60)
|
||||
|
||||
async def get_calls_in_last_hours(self, hours: int) -> int:
|
||||
"""获取过去 N 小时内的调用次数 (包括成功和失败)"""
|
||||
return await self.get_calls_in_last_seconds(hours * 3600)
|
||||
|
||||
async def get_calls_in_current_month(self) -> int:
|
||||
"""获取当前自然月内的调用次数 (包括成功和失败)"""
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
query = select(func.count(RequestLog.id)).where(
|
||||
RequestLog.request_time >= start_of_month
|
||||
)
|
||||
count_result = await database.fetch_one(query)
|
||||
return count_result[0] if count_result else 0
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get calls in current month: {e}")
|
||||
return 0 # Return 0 on error
|
||||
|
||||
async def get_api_usage_stats(self) -> dict:
|
||||
"""获取所有需要的 API 使用统计数据"""
|
||||
try:
|
||||
calls_1m = await self.get_calls_in_last_minutes(1)
|
||||
calls_1h = await self.get_calls_in_last_hours(1)
|
||||
calls_24h = await self.get_calls_in_last_hours(24)
|
||||
calls_month = await self.get_calls_in_current_month()
|
||||
|
||||
return {
|
||||
"calls_1m": calls_1m,
|
||||
"calls_1h": calls_1h,
|
||||
"calls_24h": calls_24h,
|
||||
"calls_month": calls_month,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get API usage stats: {e}")
|
||||
# Return default values on error
|
||||
return {
|
||||
"calls_1m": 0,
|
||||
"calls_1h": 0,
|
||||
"calls_24h": 0,
|
||||
"calls_month": 0,
|
||||
}
|
||||
|
||||
|
||||
async def get_api_call_details(self, period: str) -> list[dict]:
|
||||
"""
|
||||
获取指定时间段内的 API 调用详情
|
||||
|
||||
Args:
|
||||
period: 时间段标识 ('1m', '1h', '24h')
|
||||
|
||||
Returns:
|
||||
包含调用详情的字典列表,每个字典包含 timestamp, key, model, status
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 period 无效
|
||||
"""
|
||||
now = datetime.datetime.now()
|
||||
if period == '1m':
|
||||
start_time = now - datetime.timedelta(minutes=1)
|
||||
elif period == '1h':
|
||||
start_time = now - datetime.timedelta(hours=1)
|
||||
elif period == '24h':
|
||||
start_time = now - datetime.timedelta(hours=24)
|
||||
else:
|
||||
raise ValueError(f"无效的时间段标识: {period}")
|
||||
|
||||
try:
|
||||
query = select(
|
||||
RequestLog.request_time.label("timestamp"),
|
||||
RequestLog.api_key.label("key"),
|
||||
RequestLog.model_name.label("model"),
|
||||
RequestLog.status_code # We might need to map this to 'success'/'failure' later
|
||||
).where(
|
||||
RequestLog.request_time >= start_time
|
||||
).order_by(RequestLog.request_time.desc()) # Order by most recent first
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
# Convert results to list of dicts and map status_code
|
||||
details = []
|
||||
for row in results:
|
||||
status = 'failure' # 默认状态为 failure,如果 status_code 有效且在 200-299 范围内则更新为 success
|
||||
if row['status_code'] is not None: # 检查 status_code 是否为空
|
||||
status = 'success' if 200 <= row['status_code'] < 300 else 'failure'
|
||||
details.append({
|
||||
"timestamp": row['timestamp'].isoformat(), # Use ISO format for JS compatibility
|
||||
"key": row['key'],
|
||||
"model": row['model'],
|
||||
"status": status
|
||||
})
|
||||
logger.info(f"Retrieved {len(details)} API call details for period '{period}'")
|
||||
return details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get API call details for period '{period}': {e}")
|
||||
# Re-raise the exception to be handled by the route
|
||||
raise
|
||||
|
||||
async def get_key_usage_details_last_24h(self, key: str) -> dict | None:
|
||||
"""
|
||||
获取指定 API 密钥在过去 24 小时内按模型统计的调用次数。
|
||||
|
||||
Args:
|
||||
key: 要查询的 API 密钥。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是模型名称,值是调用次数。
|
||||
如果查询出错或没有找到记录,可能返回 None 或空字典。
|
||||
Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}
|
||||
"""
|
||||
logger.info(f"Fetching usage details for key ending in ...{key[-4:]} for the last 24h.")
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=24)
|
||||
|
||||
try:
|
||||
query = select(
|
||||
RequestLog.model_name,
|
||||
func.count(RequestLog.id).label("call_count")
|
||||
).where(
|
||||
RequestLog.api_key == key,
|
||||
RequestLog.request_time >= cutoff_time,
|
||||
RequestLog.model_name.isnot(None) # Ensure model_name is not null
|
||||
).group_by(
|
||||
RequestLog.model_name
|
||||
).order_by(
|
||||
func.count(RequestLog.id).desc() # Order by count descending
|
||||
)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
if not results:
|
||||
logger.info(f"No usage details found for key ending in ...{key[-4:]} in the last 24h.")
|
||||
return {} # Return empty dict if no records found
|
||||
|
||||
usage_details = {row['model_name']: row['call_count'] for row in results}
|
||||
logger.info(f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}")
|
||||
return usage_details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get key usage details for key ending in ...{key[-4:]}: {e}", exc_info=True)
|
||||
# Depending on requirements, you might return None or raise the exception
|
||||
# Raising allows the route handler to return a 500 error.
|
||||
raise # Re-raise the exception
|
||||
363
app/service/tts/native/README.md
Normal file
363
app/service/tts/native/README.md
Normal file
@@ -0,0 +1,363 @@
|
||||
# 原生Gemini TTS功能
|
||||
|
||||
这个模块为Gemini Balance项目添加了原生Gemini TTS(Text-to-Speech)功能,支持单人和多人语音合成,采用智能检测和继承模式设计,保持与原始代码的完全兼容性。
|
||||
|
||||
## 🎯 设计原则
|
||||
|
||||
- **智能检测**:自动检测所有原生Gemini TTS格式的请求(包含responseModalities和speechConfig)
|
||||
- **继承而非修改**:所有扩展都继承自原始类,不修改源码
|
||||
- **完全兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **动态模型选择**:支持用户在请求URL中指定不同的TTS模型
|
||||
- **自动回退**:原生TTS处理失败时自动回退到标准服务
|
||||
- **完整日志记录**:包含请求日志、错误日志和性能监控
|
||||
- **易于维护**:更新原始代码时不会产生冲突
|
||||
|
||||
## 📁 文件结构
|
||||
|
||||
```
|
||||
app/service/tts/
|
||||
├── tts_service.py # 原有的OpenAI兼容TTS服务
|
||||
└── native/ # 原生Gemini TTS扩展
|
||||
├── __init__.py # 模块初始化
|
||||
├── README.md # 使用说明(本文件)
|
||||
├── tts_models.py # TTS数据模型(继承自原始模型)
|
||||
├── tts_response_handler.py # TTS响应处理器(继承自原始处理器)
|
||||
├── tts_chat_service.py # TTS聊天服务(继承自原始服务)
|
||||
└── tts_routes.py # TTS路由扩展和依赖注入
|
||||
```
|
||||
|
||||
## 🚀 原生Gemini TTS功能
|
||||
|
||||
### 智能检测机制(当前实现)
|
||||
|
||||
原生Gemini TTS功能通过智能检测自动启用,无需任何配置:
|
||||
|
||||
1. **自动启用**:
|
||||
```bash
|
||||
# 直接启动服务,原生TTS功能自动可用
|
||||
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
2. **无需配置**:
|
||||
- 不需要环境变量
|
||||
- 不需要修改配置文件
|
||||
- 完全基于请求内容智能判断
|
||||
|
||||
### 工作原理
|
||||
|
||||
系统会智能检测请求内容:
|
||||
- **原生TTS请求**:包含 `responseModalities: ["AUDIO"]` 和 `speechConfig` → 使用TTS增强服务
|
||||
- **单人TTS**:包含 `voiceConfig.prebuiltVoiceConfig`
|
||||
- **多人TTS**:包含 `multiSpeakerVoiceConfig`
|
||||
- **普通请求**:非TTS模型 → 使用原有Gemini聊天服务
|
||||
|
||||
```python
|
||||
# app/router/gemini_routes.py 中的智能检测逻辑
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
# 使用TTS增强服务
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
return await tts_service.generate_content(...)
|
||||
# 否则使用原有服务
|
||||
```
|
||||
|
||||
## 📝 使用示例
|
||||
|
||||
### 1. 原生Gemini单人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `voiceConfig.prebuiltVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Hello, this is a single speaker test."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. 原生Gemini多人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `multiSpeakerVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Alice: Hello everyone, welcome to our show today.\nBob: Hi Alice, and hello to all our listeners! Today we are talking about AI development."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "Alice",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Puck"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"speaker": "Bob",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 3. OpenAI兼容TTS请求(使用原有服务)
|
||||
|
||||
OpenAI兼容格式的TTS请求使用不同的API路径,不受本模块影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1/audio/speech" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-token" \
|
||||
-d '{
|
||||
"model": "tts-1",
|
||||
"input": "这是一个OpenAI兼容格式的TTS测试。",
|
||||
"voice": "alloy"
|
||||
}' \
|
||||
--output openai_tts_test.wav
|
||||
```
|
||||
|
||||
**注意**:OpenAI兼容TTS请求:
|
||||
- 使用路径:`/v1/audio/speech`
|
||||
- 使用Authorization头而不是x-goog-api-key
|
||||
- 返回音频文件而不是JSON响应
|
||||
- 不受本模块的TTS增强服务影响
|
||||
|
||||
### 普通文本生成(使用原有服务)
|
||||
|
||||
非TTS模型的请求会使用原有的Gemini聊天服务,完全不受影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "请简单介绍一下人工智能的发展历程。"
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"maxOutputTokens": 200,
|
||||
"temperature": 0.7
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## 🔧 技术实现
|
||||
|
||||
### 继承关系
|
||||
|
||||
```
|
||||
GeminiChatService
|
||||
↓ (继承)
|
||||
TTSGeminiChatService
|
||||
├── 重写 generate_content() 方法
|
||||
├── 添加 _handle_tts_request() 方法
|
||||
└── 集成完整的日志记录功能
|
||||
|
||||
GeminiResponseHandler
|
||||
↓ (继承)
|
||||
TTSResponseHandler
|
||||
└── 重写 handle_response() 方法
|
||||
|
||||
GenerationConfig (Pydantic模型)
|
||||
↓ (扩展)
|
||||
TTSGenerationConfig
|
||||
├── responseModalities: List[str]
|
||||
└── speechConfig: Dict[str, Any]
|
||||
```
|
||||
|
||||
### 工作流程
|
||||
|
||||
1. **请求接收**:系统接收到API请求
|
||||
2. **智能检测**:
|
||||
- 检查模型名称是否包含 "tts"
|
||||
- 如果是TTS模型,从 `request.generationConfig` 检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig`
|
||||
3. **服务选择**:
|
||||
- **原生TTS请求**:使用 `TTSGeminiChatService` 增强服务
|
||||
- **普通请求**:使用原有 `GeminiChatService`
|
||||
4. **请求处理**:
|
||||
- **原生TTS**:使用 `_handle_tts_request()` 特殊处理
|
||||
- **其他请求**:使用标准 `generate_content()` 方法
|
||||
5. **字段处理**:从 `request.generationConfig` 直接获取TTS字段(`responseModalities`, `speechConfig`)
|
||||
6. **API调用**:构建优化的payload并调用Gemini API
|
||||
7. **自动回退**:如果原生TTS处理失败,自动回退到标准服务
|
||||
8. **响应处理**:
|
||||
- **TTS响应**:检测音频数据,直接返回原始响应
|
||||
- **普通响应**:使用标准处理方法
|
||||
9. **日志记录**:记录请求时间、成功状态、错误信息到数据库
|
||||
|
||||
## 📊 功能特性
|
||||
|
||||
### ✅ 已实现功能
|
||||
|
||||
- **智能原生TTS支持**:支持单人和多人语音合成
|
||||
- **单人TTS**:支持 `voiceConfig.prebuiltVoiceConfig` 配置
|
||||
- **多人TTS**:支持 `multiSpeakerVoiceConfig` 配置
|
||||
- **智能检测机制**:自动检测所有原生Gemini TTS格式的请求
|
||||
- **动态模型选择**:支持用户在URL中指定不同TTS模型
|
||||
- **完全向后兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **自动回退机制**:原生TTS处理失败时自动使用标准服务
|
||||
- **完整日志记录**:请求日志、错误日志、性能监控
|
||||
- **API配额管理**:自动重试和密钥轮换
|
||||
- **零配置启用**:无需环境变量或配置文件修改
|
||||
- **错误处理**:完整的异常捕获和错误记录
|
||||
|
||||
### 🎵 支持的语音配置
|
||||
|
||||
#### 单人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 多人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "角色名称",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### API要求
|
||||
- 确保API密钥有TTS权限
|
||||
- TTS功能需要 `gemini-2.5-flash-preview-tts` 模型
|
||||
- 注意API配额限制(免费版每天15次)
|
||||
|
||||
### 性能考虑
|
||||
- TTS响应通常比文本响应更大(音频数据)
|
||||
- 建议监控API调用频率和成功率
|
||||
- 扩展功能不影响原始功能的性能和稳定性
|
||||
|
||||
### 部署建议
|
||||
- 生产环境建议先测试普通功能
|
||||
- 逐步启用TTS功能并监控日志
|
||||
- 定期检查API配额使用情况
|
||||
|
||||
## 📈 监控和调试
|
||||
|
||||
### 日志查看
|
||||
- **服务器日志**:查看TTS请求处理过程
|
||||
- **管理界面**:在"API 调用详情"中查看请求记录
|
||||
- **错误日志**:查看失败请求的详细信息
|
||||
|
||||
### 调试技巧
|
||||
```bash
|
||||
# 启用详细日志
|
||||
export LOG_LEVEL=DEBUG
|
||||
|
||||
# 查看实时日志
|
||||
tail -f logs/app.log
|
||||
|
||||
# 多人TTS功能无需配置,自动启用
|
||||
# 可通过请求内容智能检测
|
||||
```
|
||||
|
||||
## 🔄 TTS系统对比
|
||||
|
||||
项目中现在有三套TTS系统,各自服务不同的用途:
|
||||
|
||||
| TTS类型 | 路径 | 模型选择 | 语音配置 | 使用场景 | 我们的影响 |
|
||||
|---------|------|----------|----------|----------|------------|
|
||||
| **OpenAI兼容TTS** | `/v1/audio/speech` | 固定配置文件 | 单人语音 | OpenAI API兼容 | ✅ 无影响 |
|
||||
| **Gemini单人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 单人语音 | 原生Gemini TTS | ✅ 我们的增强 |
|
||||
| **Gemini多人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 多人语音 | 对话场景 | ✅ 我们的增强 |
|
||||
|
||||
### 智能路由机制
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[API请求] --> B{路径检查}
|
||||
B -->|/v1/audio/speech| C[OpenAI兼容TTS服务]
|
||||
B -->|/v1beta/models/{model}:generateContent| D{模型名包含'tts'?}
|
||||
D -->|否| E[标准Gemini聊天服务]
|
||||
D -->|是| F{包含responseModalities和speechConfig?}
|
||||
F -->|否| G[标准Gemini聊天服务]
|
||||
F -->|是| H[原生TTS增强服务]
|
||||
H --> I{处理成功?}
|
||||
I -->|是| J[返回原生TTS响应]
|
||||
I -->|否| K[自动回退到标准服务]
|
||||
C --> L[完成]
|
||||
E --> L
|
||||
G --> L
|
||||
J --> L
|
||||
K --> L
|
||||
```
|
||||
|
||||
## 🎉 成功案例
|
||||
|
||||
基于智能检测的原生Gemini TTS解决方案已经成功实现:
|
||||
|
||||
- ✅ **零配置启用**:无需任何环境变量或配置修改
|
||||
- ✅ **智能检测**:自动检测所有原生Gemini TTS格式的请求
|
||||
- ✅ **完全向后兼容**:所有原有TTS功能零影响
|
||||
- ✅ **动态模型选择**:支持用户指定不同TTS模型
|
||||
- ✅ **自动回退机制**:处理失败时自动使用标准服务
|
||||
- ✅ **单人和多人语音合成**:支持所有原生Gemini TTS场景
|
||||
- ✅ **完整日志记录**:可在管理界面查看所有请求
|
||||
- ✅ **错误处理完善**:API配额和重试机制
|
||||
- ✅ **易于维护**:更新原始代码无冲突
|
||||
|
||||
这个实现展示了如何在不修改原始代码的情况下,优雅地扩展复杂系统的功能,同时保持完美的向后兼容性。
|
||||
19
app/service/tts/native/__init__.py
Normal file
19
app/service/tts/native/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
原生Gemini TTS功能模块
|
||||
Native Gemini TTS functionality for both single and multi-speaker scenarios
|
||||
"""
|
||||
|
||||
from .tts_chat_service import TTSGeminiChatService
|
||||
from .tts_models import TTSGenerationConfig, MultiSpeakerVoiceConfig, SpeechConfig, TTSRequest
|
||||
from .tts_response_handler import TTSResponseHandler
|
||||
from .tts_routes import get_tts_chat_service
|
||||
|
||||
__all__ = [
|
||||
"TTSGeminiChatService",
|
||||
"TTSGenerationConfig",
|
||||
"MultiSpeakerVoiceConfig",
|
||||
"SpeechConfig",
|
||||
"TTSRequest",
|
||||
"TTSResponseHandler",
|
||||
"get_tts_chat_service"
|
||||
]
|
||||
151
app/service/tts/native/tts_chat_service.py
Normal file
151
app/service/tts/native/tts_chat_service.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
原生Gemini TTS聊天服务扩展
|
||||
继承自原始聊天服务,添加原生Gemini TTS支持(单人和多人),保持向后兼容
|
||||
"""
|
||||
|
||||
import time
|
||||
import datetime
|
||||
from typing import Any, Dict
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_response_handler import TTSResponseHandler
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.database.services import add_request_log, add_error_log
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSGeminiChatService(GeminiChatService):
|
||||
"""
|
||||
支持TTS的Gemini聊天服务
|
||||
继承自原始的GeminiChatService,添加TTS功能
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager):
|
||||
"""
|
||||
初始化TTS聊天服务
|
||||
"""
|
||||
super().__init__(base_url, key_manager)
|
||||
# 使用TTS响应处理器替换原始处理器
|
||||
self.response_handler = TTSResponseHandler()
|
||||
logger.info("TTS Gemini Chat Service initialized with multi-speaker TTS support")
|
||||
|
||||
async def generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成内容,支持TTS
|
||||
"""
|
||||
try:
|
||||
# 添加调试日志
|
||||
logger.info(f"TTS request model: {model}")
|
||||
logger.info(f"TTS request generationConfig: {request.generationConfig}")
|
||||
|
||||
# 检查是否是TTS模型,如果是,需要特殊处理
|
||||
if "tts" in model.lower():
|
||||
logger.info("Detected TTS model, applying TTS-specific processing")
|
||||
# 对于TTS模型,我们需要确保正确的字段被传递
|
||||
response = await self._handle_tts_request(model, request, api_key)
|
||||
return response
|
||||
else:
|
||||
# 对于非TTS模型,使用父类的方法
|
||||
response = await super().generate_content(model, request, api_key)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"TTS API call failed with error: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_tts_request(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
|
||||
"""
|
||||
处理TTS特定的请求,包含完整的日志记录功能
|
||||
"""
|
||||
# 记录开始时间和请求时间
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
|
||||
try:
|
||||
# 构建TTS专用的payload - 不包含tools和safetySettings
|
||||
from app.service.chat.gemini_chat_service import _filter_empty_parts
|
||||
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
|
||||
# 构建TTS专用的简化payload
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"generationConfig": request_dict.get("generationConfig", {}),
|
||||
}
|
||||
|
||||
# 只在有systemInstruction时才添加
|
||||
if request_dict.get("systemInstruction"):
|
||||
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
||||
|
||||
# 确保 generationConfig 不为 None
|
||||
if payload["generationConfig"] is None:
|
||||
payload["generationConfig"] = {}
|
||||
|
||||
# 从request.generationConfig直接获取TTS相关字段
|
||||
if request.generationConfig:
|
||||
# 添加TTS特定字段
|
||||
if request.generationConfig.responseModalities:
|
||||
payload["generationConfig"]["responseModalities"] = request.generationConfig.responseModalities
|
||||
logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}")
|
||||
|
||||
if request.generationConfig.speechConfig:
|
||||
payload["generationConfig"]["speechConfig"] = request.generationConfig.speechConfig
|
||||
logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}")
|
||||
else:
|
||||
logger.warning("No generationConfig found in request, TTS fields may be missing")
|
||||
|
||||
logger.info(f"TTS payload before API call: {payload}")
|
||||
|
||||
# 调用API
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
|
||||
# 如果到达这里,说明API调用成功
|
||||
is_success = True
|
||||
status_code = 200
|
||||
|
||||
# 使用TTS响应处理器处理响应
|
||||
return self.response_handler.handle_response(response, model, False, None)
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
is_success = False
|
||||
error_msg = str(e)
|
||||
|
||||
# 尝试从错误消息中提取状态码
|
||||
import re
|
||||
match = re.search(r"status code (\d+)", error_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# 添加错误日志
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="tts-api-error",
|
||||
error_log=error_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request.model_dump(exclude_none=False)
|
||||
)
|
||||
|
||||
logger.error(f"TTS API call failed: {error_msg}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
# 记录请求日志
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
37
app/service/tts/native/tts_config.py
Normal file
37
app/service/tts/native/tts_config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
TTS扩展配置
|
||||
控制是否启用TTS功能
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
class TTSConfig:
|
||||
"""TTS配置管理"""
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled() -> bool:
|
||||
"""
|
||||
检查是否启用TTS功能
|
||||
通过环境变量 ENABLE_TTS 控制,默认为 False
|
||||
"""
|
||||
return os.getenv("ENABLE_TTS", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
@staticmethod
|
||||
def get_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
|
||||
"""
|
||||
工厂方法:根据配置返回合适的聊天服务
|
||||
"""
|
||||
if TTSConfig.is_tts_enabled():
|
||||
return TTSGeminiChatService(base_url, key_manager)
|
||||
else:
|
||||
return GeminiChatService(base_url, key_manager)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def create_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
|
||||
"""创建聊天服务实例"""
|
||||
return TTSConfig.get_chat_service(base_url, key_manager)
|
||||
36
app/service/tts/native/tts_models.py
Normal file
36
app/service/tts/native/tts_models.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
原生Gemini TTS扩展数据模型
|
||||
继承自原始模型,添加原生Gemini TTS相关字段,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.domain.gemini_models import GenerationConfig as BaseGenerationConfig
|
||||
|
||||
|
||||
class TTSGenerationConfig(BaseGenerationConfig):
|
||||
"""
|
||||
支持TTS的生成配置类
|
||||
继承自原始的GenerationConfig,添加TTS相关字段
|
||||
"""
|
||||
# TTS 相关配置
|
||||
responseModalities: Optional[List[str]] = None
|
||||
speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class MultiSpeakerVoiceConfig(BaseModel):
|
||||
"""多人语音配置"""
|
||||
speakerVoiceConfigs: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SpeechConfig(BaseModel):
|
||||
"""语音配置"""
|
||||
multiSpeakerVoiceConfig: Optional[MultiSpeakerVoiceConfig] = None
|
||||
voiceConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
"""TTS请求模型"""
|
||||
contents: List[Dict[str, Any]]
|
||||
generationConfig: TTSGenerationConfig
|
||||
53
app/service/tts/native/tts_response_handler.py
Normal file
53
app/service/tts/native/tts_response_handler.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
原生Gemini TTS响应处理器扩展
|
||||
继承自原始响应处理器,添加原生Gemini TTS支持,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.log.logger import get_gemini_logger
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSResponseHandler(GeminiResponseHandler):
|
||||
"""
|
||||
支持TTS的响应处理器
|
||||
继承自原始的GeminiResponseHandler,添加TTS响应处理
|
||||
"""
|
||||
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理响应,支持TTS音频数据
|
||||
"""
|
||||
# 检查是否是TTS响应(包含音频数据)
|
||||
if self._is_tts_response(response):
|
||||
logger.info("Detected TTS response with audio data, returning original response")
|
||||
return response
|
||||
|
||||
# 对于非TTS响应,使用父类的处理方法
|
||||
return super().handle_response(response, model, stream, usage_metadata)
|
||||
|
||||
def _is_tts_response(self, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
检查是否是TTS响应
|
||||
"""
|
||||
try:
|
||||
if (response.get("candidates") and
|
||||
len(response["candidates"]) > 0 and
|
||||
response["candidates"][0].get("content") and
|
||||
response["candidates"][0]["content"].get("parts") and
|
||||
len(response["candidates"][0]["content"]["parts"]) > 0):
|
||||
|
||||
parts = response["candidates"][0]["content"]["parts"]
|
||||
for part in parts:
|
||||
if "inlineData" in part:
|
||||
mime_type = part["inlineData"].get("mimeType", "")
|
||||
if mime_type.startswith("audio/"):
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking TTS response: {e}")
|
||||
return False
|
||||
24
app/service/tts/native/tts_routes.py
Normal file
24
app/service/tts/native/tts_routes.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
TTS路由扩展
|
||||
提供原生Gemini TTS增强服务,支持单人和多人语音
|
||||
"""
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from app.config.config import settings
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
"""获取密钥管理器实例"""
|
||||
return get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_tts_chat_service(key_manager: KeyManager = Depends(get_key_manager)) -> TTSGeminiChatService:
|
||||
"""
|
||||
获取原生Gemini TTS增强聊天服务实例,支持单人和多人语音
|
||||
"""
|
||||
return TTSGeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
95
app/service/tts/tts_service.py
Normal file
95
app/service/tts/tts_service.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import datetime
|
||||
import io
|
||||
import re
|
||||
import time
|
||||
import wave
|
||||
from typing import Optional
|
||||
|
||||
from google import genai
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import TTS_VOICE_NAMES
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.openai_models import TTSRequest
|
||||
from app.log.logger import get_openai_logger
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _create_wav_file(audio_data: bytes) -> bytes:
|
||||
"""Creates a WAV file in memory from raw audio data."""
|
||||
with io.BytesIO() as wav_file:
|
||||
with wave.open(wav_file, "wb") as wf:
|
||||
wf.setnchannels(1) # Mono
|
||||
wf.setsampwidth(2) # 16-bit
|
||||
wf.setframerate(24000) # 24kHz sample rate
|
||||
wf.writeframes(audio_data)
|
||||
return wav_file.getvalue()
|
||||
|
||||
|
||||
class TTSService:
|
||||
async def create_tts(self, request: TTSRequest, api_key: str) -> Optional[bytes]:
|
||||
"""
|
||||
使用 Google Gemini SDK 创建音频。
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
error_log_msg = ""
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
response =await client.aio.models.generate_content(
|
||||
model=settings.TTS_MODEL,
|
||||
contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
|
||||
config={
|
||||
"response_modalities": ["Audio"],
|
||||
"speech_config": {
|
||||
"voice_config": {
|
||||
"prebuilt_voice_config": {
|
||||
"voice_name": request.voice if request.voice in TTS_VOICE_NAMES else settings.TTS_VOICE_NAME
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
if (
|
||||
response.candidates
|
||||
and response.candidates[0].content.parts
|
||||
and response.candidates[0].content.parts[0].inline_data
|
||||
):
|
||||
raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return _create_wav_file(raw_audio_data)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Generic error: {e}"
|
||||
logger.error(f"An error occurred in TTSService: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", str(e))
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
if not is_success:
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=settings.TTS_MODEL,
|
||||
error_type="google-tts",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request.input
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=settings.TTS_MODEL,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
@@ -7,11 +7,7 @@ from app.log.logger import get_update_logger
|
||||
|
||||
logger = get_update_logger()
|
||||
|
||||
# GitHub repository details are read from settings (defined in app/config/config.py or environment variables)
|
||||
|
||||
# GITHUB_API_URL will be constructed inside the function to ensure settings are loaded
|
||||
|
||||
VERSION_FILE_PATH = "VERSION" # Path relative to project root
|
||||
VERSION_FILE_PATH = "VERSION"
|
||||
|
||||
async def check_for_updates() -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
"""
|
||||
@@ -24,9 +20,6 @@ async def check_for_updates() -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
- Optional[str]: 如果检查失败,则为错误消息,否则为 None。
|
||||
"""
|
||||
try:
|
||||
# Read current version from VERSION file
|
||||
# Ensure the path is correct relative to the execution context or use absolute path if needed
|
||||
# Assuming execution from project root d:/develop/pythonProjects/gemini-balance
|
||||
with open(VERSION_FILE_PATH, 'r', encoding='utf-8') as f:
|
||||
current_v = f.read().strip()
|
||||
if not current_v:
|
||||
@@ -41,25 +34,22 @@ async def check_for_updates() -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
|
||||
logger.info(f"当前应用程序版本 (from {VERSION_FILE_PATH}): {current_v}")
|
||||
|
||||
# Check if repository details are configured in settings
|
||||
if not settings.GITHUB_REPO_OWNER or not settings.GITHUB_REPO_NAME or \
|
||||
settings.GITHUB_REPO_OWNER == "your_owner" or settings.GITHUB_REPO_NAME == "your_repo":
|
||||
logger.warning("GitHub repository owner/name not configured in settings. Skipping update check.")
|
||||
return False, None, "Update check skipped: Repository not configured in settings."
|
||||
|
||||
# Construct the API URL inside the function to ensure settings are loaded
|
||||
github_api_url = f"https://api.github.com/repos/{settings.GITHUB_REPO_OWNER}/{settings.GITHUB_REPO_NAME}/releases/latest"
|
||||
logger.debug(f"Checking for updates at URL: {github_api_url}") # Log the URL for debugging
|
||||
logger.debug(f"Checking for updates at URL: {github_api_url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# 添加 User-Agent 头,GitHub API 可能需要
|
||||
headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": f"{settings.GITHUB_REPO_NAME}-UpdateChecker/1.0" # Use repo name from settings for User-Agent
|
||||
"User-Agent": f"{settings.GITHUB_REPO_NAME}-UpdateChecker/1.0"
|
||||
}
|
||||
response = await client.get(github_api_url, headers=headers) # Use the locally constructed URL
|
||||
response.raise_for_status() # 对错误的 HTTP 状态码(4xx 或 5xx)抛出异常
|
||||
response = await client.get(github_api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
latest_release = response.json()
|
||||
latest_v_str = latest_release.get("tag_name")
|
||||
@@ -68,7 +58,6 @@ async def check_for_updates() -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
logger.warning("在最新的 GitHub release 响应中找不到 'tag_name'。")
|
||||
return False, None, "无法从 GitHub 解析最新版本。"
|
||||
|
||||
# 移除 tag 名称中可能存在的 'v' 前缀
|
||||
if latest_v_str.startswith('v'):
|
||||
latest_v_str = latest_v_str[1:]
|
||||
|
||||
@@ -98,8 +87,6 @@ async def check_for_updates() -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
logger.error(f"检查更新时发生网络错误: {e}")
|
||||
return False, None, "更新检查期间发生网络错误。"
|
||||
except version.InvalidVersion:
|
||||
# Note: latest_v_str might not be defined if the error occurs before fetching it.
|
||||
# Consider adding a check or default value for logging.
|
||||
latest_v_str_for_log = latest_v_str if 'latest_v_str' in locals() else 'N/A'
|
||||
logger.error(f"发现无效的版本格式。当前 (from {VERSION_FILE_PATH}): '{current_v}', 最新: '{latest_v_str_for_log}'")
|
||||
return False, None, "遇到无效的版本格式。"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -17,13 +17,27 @@ self.addEventListener('install', event => {
|
||||
|
||||
self.addEventListener('fetch', event => {
|
||||
event.respondWith(
|
||||
caches.match(event.request)
|
||||
.then(response => {
|
||||
if (response) {
|
||||
return response;
|
||||
}
|
||||
return fetch(event.request);
|
||||
})
|
||||
caches.open(CACHE_NAME).then(cache => {
|
||||
// 1. 尝试从缓存获取
|
||||
return cache.match(event.request).then(responseFromCache => {
|
||||
// 2. 同时从网络获取 (后台进行)
|
||||
const fetchPromise = fetch(event.request).then(responseFromNetwork => {
|
||||
// 3. 网络请求成功,更新缓存
|
||||
cache.put(event.request, responseFromNetwork.clone());
|
||||
return responseFromNetwork;
|
||||
}).catch(err => {
|
||||
// 网络请求失败时,可以选择记录错误或不执行任何操作
|
||||
console.error('Network fetch failed:', err);
|
||||
// 确保即使网络失败,如果缓存存在,我们仍然返回缓存
|
||||
// 如果缓存也不存在,则此 Promise 会 reject
|
||||
throw err;
|
||||
});
|
||||
|
||||
// 4. 如果缓存存在,立即返回缓存;否则等待网络响应
|
||||
// 后台的网络请求仍在进行,用于更新缓存
|
||||
return responseFromCache || fetchPromise;
|
||||
});
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -6,13 +6,14 @@
|
||||
<style>
|
||||
/* auth.html specific styles */
|
||||
.auth-glass-card { /* Renamed to avoid conflict if base.html has .glass-card */
|
||||
background: rgba(255, 255, 255, 0.85); /* Increased opacity */
|
||||
background: rgba(255, 255, 255, 0.95); /* High opacity white for light theme */
|
||||
backdrop-filter: blur(20px);
|
||||
-webkit-backdrop-filter: blur(20px);
|
||||
border: 1px solid rgba(255, 255, 255, 0.2);
|
||||
border: 1px solid rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 10px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04);
|
||||
}
|
||||
.auth-bg-gradient { /* Renamed to avoid conflict if base.html has .bg-gradient */
|
||||
background: linear-gradient(135deg, #4F46E5 0%, #7C3AED 50%, #EC4899 100%);
|
||||
background: #f8fafc; /* Light gray background for auth page */
|
||||
}
|
||||
/* .input-icon class removed, using direct Tailwind classes now */
|
||||
/* Keep button ripple effect if needed, or remove if base provides similar */
|
||||
@@ -49,7 +50,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h2 class="text-3xl font-extrabold text-center text-transparent bg-clip-text bg-gradient-to-r from-primary-600 to-primary-700 mb-8 animate-slide-down">
|
||||
<h2 class="text-3xl font-extrabold text-center text-gray-800 mb-8 animate-slide-down">
|
||||
<img src="/static/icons/logo.png" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
|
||||
Gemini Balance
|
||||
</h2>
|
||||
@@ -67,9 +68,9 @@
|
||||
>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
class="w-full py-4 rounded-xl bg-gradient-to-r from-primary-600 to-primary-700 text-white font-semibold transition duration-300 transform hover:-translate-y-1 hover:shadow-lg"
|
||||
<button
|
||||
type="submit"
|
||||
class="w-full py-4 rounded-xl bg-blue-600 hover:bg-blue-700 text-white font-semibold transition duration-300 transform hover:-translate-y-1 hover:shadow-lg"
|
||||
>
|
||||
登录
|
||||
</button>
|
||||
|
||||
@@ -1,279 +1,642 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>{% block title %}Gemini Balance{% endblock %}</title>
|
||||
<link rel="manifest" href="/static/manifest.json">
|
||||
<meta name="theme-color" content="#4F46E5">
|
||||
<meta name="apple-mobile-web-app-capable" content="yes">
|
||||
<meta name="apple-mobile-web-app-status-bar-style" content="black">
|
||||
<meta name="apple-mobile-web-app-title" content="GBalance">
|
||||
<link rel="icon" href="/static/icons/icon-192x192.png">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||
<link rel="manifest" href="/static/manifest.json" />
|
||||
<meta name="theme-color" content="#4F46E5" />
|
||||
<meta name="apple-mobile-web-app-capable" content="yes" />
|
||||
<meta name="apple-mobile-web-app-status-bar-style" content="black" />
|
||||
<meta name="apple-mobile-web-app-title" content="GBalance" />
|
||||
<link rel="icon" href="/static/icons/icon-192x192.png" />
|
||||
<link
|
||||
href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css"
|
||||
/>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script>
|
||||
tailwind.config = {
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
primary: {
|
||||
50: '#eef2ff',
|
||||
100: '#e0e7ff',
|
||||
200: '#c7d2fe',
|
||||
300: '#a5b4fc',
|
||||
400: '#818cf8',
|
||||
500: '#6366f1',
|
||||
600: '#4f46e5',
|
||||
700: '#4338ca',
|
||||
800: '#3730a3',
|
||||
900: '#312e81',
|
||||
},
|
||||
success: {
|
||||
50: '#ecfdf5',
|
||||
500: '#10b981',
|
||||
600: '#059669'
|
||||
},
|
||||
danger: {
|
||||
50: '#fef2f2',
|
||||
500: '#ef4444',
|
||||
600: '#dc2626'
|
||||
}
|
||||
},
|
||||
fontFamily: {
|
||||
sans: ['Inter', 'sans-serif'],
|
||||
mono: ['JetBrains Mono', 'SFMono-Regular', 'Menlo', 'Monaco', 'Consolas', 'monospace'],
|
||||
},
|
||||
animation: {
|
||||
'fade-in': 'fadeIn 0.5s ease-out',
|
||||
'slide-up': 'slideUp 0.5s ease-out',
|
||||
'slide-down': 'slideDown 0.5s ease-out',
|
||||
'shake': 'shake 0.5s ease-in-out',
|
||||
'spin': 'spin 1s linear infinite',
|
||||
},
|
||||
keyframes: {
|
||||
fadeIn: {
|
||||
'0%': { opacity: '0' },
|
||||
'100%': { opacity: '1' },
|
||||
},
|
||||
slideUp: {
|
||||
'0%': { transform: 'translateY(20px)', opacity: '0' },
|
||||
'100%': { transform: 'translateY(0)', opacity: '1' },
|
||||
},
|
||||
slideDown: {
|
||||
'0%': { transform: 'translateY(-20px)', opacity: '0' },
|
||||
'100%': { transform: 'translateY(0)', opacity: '1' },
|
||||
},
|
||||
shake: {
|
||||
'0%, 100%': { transform: 'translateX(0)' },
|
||||
'25%': { transform: 'translateX(-5px)' },
|
||||
'75%': { transform: 'translateX(5px)' },
|
||||
},
|
||||
spin: {
|
||||
'0%': { transform: 'rotate(0deg)' },
|
||||
'100%': { transform: 'rotate(360deg)' },
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
tailwind.config = {
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
primary: {
|
||||
50: "#eef2ff",
|
||||
100: "#e0e7ff",
|
||||
200: "#c7d2fe",
|
||||
300: "#a5b4fc",
|
||||
400: "#818cf8",
|
||||
500: "#6366f1",
|
||||
600: "#4f46e5",
|
||||
700: "#4338ca",
|
||||
800: "#3730a3",
|
||||
900: "#312e81",
|
||||
},
|
||||
success: {
|
||||
50: "#ecfdf5",
|
||||
500: "#10b981",
|
||||
600: "#059669",
|
||||
},
|
||||
danger: {
|
||||
50: "#fef2f2",
|
||||
500: "#ef4444",
|
||||
600: "#dc2626",
|
||||
},
|
||||
},
|
||||
fontFamily: {
|
||||
sans: ["Inter", "sans-serif"],
|
||||
mono: [
|
||||
"JetBrains Mono",
|
||||
"SFMono-Regular",
|
||||
"Menlo",
|
||||
"Monaco",
|
||||
"Consolas",
|
||||
"monospace",
|
||||
],
|
||||
},
|
||||
animation: {
|
||||
"fade-in": "fadeIn 0.5s ease-out",
|
||||
"slide-up": "slideUp 0.5s ease-out",
|
||||
"slide-down": "slideDown 0.5s ease-out",
|
||||
shake: "shake 0.5s ease-in-out",
|
||||
spin: "spin 1s linear infinite",
|
||||
},
|
||||
keyframes: {
|
||||
fadeIn: {
|
||||
"0%": { opacity: "0" },
|
||||
"100%": { opacity: "1" },
|
||||
},
|
||||
slideUp: {
|
||||
"0%": { transform: "translateY(20px)", opacity: "0" },
|
||||
"100%": { transform: "translateY(0)", opacity: "1" },
|
||||
},
|
||||
slideDown: {
|
||||
"0%": { transform: "translateY(-20px)", opacity: "0" },
|
||||
"100%": { transform: "translateY(0)", opacity: "1" },
|
||||
},
|
||||
shake: {
|
||||
"0%, 100%": { transform: "translateX(0)" },
|
||||
"25%": { transform: "translateX(-5px)" },
|
||||
"75%": { transform: "translateX(5px)" },
|
||||
},
|
||||
spin: {
|
||||
"0%": { transform: "rotate(0deg)" },
|
||||
"100%": { transform: "rotate(360deg)" },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
</script>
|
||||
<style>
|
||||
.glass-card {
|
||||
background: rgba(255, 255, 255, 0.85); /* Slightly increased opacity for better readability */
|
||||
backdrop-filter: blur(16px);
|
||||
-webkit-backdrop-filter: blur(16px);
|
||||
border: 1px solid rgba(255, 255, 255, 0.18); /* Subtle border */
|
||||
}
|
||||
.bg-gradient {
|
||||
background: linear-gradient(135deg, #4F46E5 0%, #7C3AED 50%, #EC4899 100%);
|
||||
}
|
||||
/* Scrollbar styling */
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
}
|
||||
::-webkit-scrollbar-track {
|
||||
background: rgba(243, 244, 246, 0.8); /* bg-gray-100 with opacity */
|
||||
border-radius: 10px;
|
||||
}
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: rgba(79, 70, 229, 0.4); /* primary-600 with opacity */
|
||||
border-radius: 10px;
|
||||
}
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(79, 70, 229, 0.6); /* primary-600 with more opacity */
|
||||
}
|
||||
/* Basic modal styles */
|
||||
.modal {
|
||||
display: none;
|
||||
position: fixed;
|
||||
z-index: 50;
|
||||
left: 0;
|
||||
top: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(0,0,0,0.5);
|
||||
backdrop-filter: blur(4px);
|
||||
}
|
||||
.modal.show {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
/* Loading spinner */
|
||||
.loading-spin {
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
@keyframes spin {
|
||||
from { transform: rotate(0deg); }
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
/* Notification */
|
||||
.notification {
|
||||
position: fixed;
|
||||
bottom: 5rem; /* Adjusted from bottom-20 */
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
padding: 0.75rem 1.25rem; /* px-5 py-3 */
|
||||
border-radius: 0.5rem; /* rounded-lg */
|
||||
background-color: rgba(0, 0, 0, 0.8);
|
||||
color: white;
|
||||
font-weight: 500; /* font-medium */
|
||||
z-index: 50;
|
||||
opacity: 0;
|
||||
transition: opacity 0.3s ease-in-out, transform 0.3s ease-in-out;
|
||||
}
|
||||
.notification.show {
|
||||
opacity: 1;
|
||||
transform: translate(-50%, 0);
|
||||
}
|
||||
.notification.error {
|
||||
background-color: rgba(220, 38, 38, 0.8); /* danger-600 with opacity */
|
||||
}
|
||||
/* Scroll buttons */
|
||||
.scroll-buttons {
|
||||
position: fixed;
|
||||
right: 1.25rem; /* right-5 */
|
||||
bottom: 5rem; /* bottom-20 */
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem; /* gap-2 */
|
||||
z-index: 10;
|
||||
}
|
||||
.scroll-button {
|
||||
width: 2.5rem; /* w-10 */
|
||||
height: 2.5rem; /* h-10 */
|
||||
background-color: #4f46e5; /* bg-primary-600 */
|
||||
color: white;
|
||||
border-radius: 9999px; /* rounded-full */
|
||||
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); /* shadow-md */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.3s ease-in-out;
|
||||
}
|
||||
.scroll-button:hover {
|
||||
background-color: #4338ca; /* hover:bg-primary-700 */
|
||||
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); /* hover:shadow-lg */
|
||||
}
|
||||
{% block head_extra_styles %}
|
||||
{% endblock %}
|
||||
.glass-card {
|
||||
background: rgba(255, 255, 255, 0.95); /* High opacity white for light theme */
|
||||
backdrop-filter: blur(16px);
|
||||
-webkit-backdrop-filter: blur(16px);
|
||||
border: 1px solid rgba(0, 0, 0, 0.08); /* Light gray border */
|
||||
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
|
||||
}
|
||||
.bg-gradient {
|
||||
background: #ffffff; /* Clean white background */
|
||||
}
|
||||
/* Scrollbar styling */
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
}
|
||||
::-webkit-scrollbar-track {
|
||||
background: rgba(243, 244, 246, 0.8); /* bg-gray-100 with opacity */
|
||||
border-radius: 10px;
|
||||
}
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: rgba(107, 114, 128, 0.6); /* gray-500 for light theme */
|
||||
border-radius: 10px;
|
||||
}
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(75, 85, 99, 0.8); /* gray-600 for light theme */
|
||||
}
|
||||
/* Basic modal styles */
|
||||
.modal {
|
||||
display: none;
|
||||
position: fixed;
|
||||
z-index: 50;
|
||||
left: 0;
|
||||
top: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(0,0,0,0.5);
|
||||
backdrop-filter: blur(4px);
|
||||
}
|
||||
.modal.show {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
/* Global modal content styling for light theme consistency */
|
||||
.modal .w-full[style*="background-color: rgba(70, 50, 150"],
|
||||
.modal .w-full[style*="background-color: rgba(80, 60, 160"] {
|
||||
background-color: rgba(255, 255, 255, 0.98) !important;
|
||||
color: #374151 !important; /* gray-700 */
|
||||
border: 1px solid rgba(0, 0, 0, 0.08) !important;
|
||||
box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04) !important;
|
||||
}
|
||||
|
||||
/* Global modal text color fixes */
|
||||
.modal .text-gray-100, .modal h2.text-gray-100, .modal h3.text-gray-100 {
|
||||
color: #1f2937 !important; /* gray-800 */
|
||||
font-weight: 600 !important;
|
||||
}
|
||||
|
||||
.modal .text-gray-200, .modal .text-gray-300 {
|
||||
color: #6b7280 !important; /* gray-500 */
|
||||
}
|
||||
|
||||
.modal .text-gray-300:hover {
|
||||
color: #374151 !important; /* gray-700 */
|
||||
}
|
||||
|
||||
/* Global modal button styling */
|
||||
.modal .bg-violet-600, .modal button.bg-violet-600 {
|
||||
background-color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
color: #ffffff !important;
|
||||
}
|
||||
|
||||
.modal .bg-violet-600:hover, .modal button.bg-violet-600:hover {
|
||||
background-color: #2563eb !important; /* blue-600 - darker light blue */
|
||||
}
|
||||
|
||||
/* Global modal blue button styling */
|
||||
.modal .bg-blue-500, .modal button.bg-blue-500,
|
||||
.modal .bg-blue-600, .modal button.bg-blue-600,
|
||||
.modal .bg-blue-700, .modal button.bg-blue-700 {
|
||||
background-color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
color: #ffffff !important;
|
||||
}
|
||||
|
||||
.modal .bg-blue-500:hover, .modal button.bg-blue-500:hover,
|
||||
.modal .bg-blue-600:hover, .modal button.bg-blue-600:hover,
|
||||
.modal .bg-blue-700:hover, .modal button.bg-blue-700:hover {
|
||||
background-color: #2563eb !important; /* blue-600 - darker light blue */
|
||||
}
|
||||
|
||||
/* Global modal red button styling */
|
||||
.modal .bg-red-500, .modal button.bg-red-500,
|
||||
.modal .bg-red-600, .modal button.bg-red-600,
|
||||
.modal .bg-red-700, .modal button.bg-red-700 {
|
||||
background-color: #f87171 !important; /* red-400 - bright light red */
|
||||
color: #ffffff !important;
|
||||
}
|
||||
|
||||
.modal .bg-red-500:hover, .modal button.bg-red-500:hover,
|
||||
.modal .bg-red-600:hover, .modal button.bg-red-600:hover,
|
||||
.modal .bg-red-700:hover, .modal button.bg-red-700:hover {
|
||||
background-color: #ef4444 !important; /* red-500 - darker bright light red */
|
||||
}
|
||||
|
||||
/* Global modal gray button styling */
|
||||
.modal .bg-gray-500, .modal button.bg-gray-500,
|
||||
.modal .bg-gray-600, .modal button.bg-gray-600,
|
||||
.modal .bg-gray-700, .modal button.bg-gray-700 {
|
||||
background-color: #e5e7eb !important; /* gray-200 - light gray */
|
||||
color: #374151 !important; /* gray-700 - dark text for contrast */
|
||||
}
|
||||
|
||||
.modal .bg-gray-500:hover, .modal button.bg-gray-500:hover,
|
||||
.modal .bg-gray-600:hover, .modal button.bg-gray-600:hover,
|
||||
.modal .bg-gray-700:hover, .modal button.bg-gray-700:hover {
|
||||
background-color: #d1d5db !important; /* gray-300 - darker light gray */
|
||||
color: #374151 !important; /* gray-700 - dark text for contrast */
|
||||
}
|
||||
|
||||
/* Comprehensive button contrast fixes */
|
||||
/* Ensure all dark background buttons have white text */
|
||||
.bg-blue-500, .bg-blue-600, .bg-blue-700, .bg-blue-800, .bg-blue-900,
|
||||
.bg-red-500, .bg-red-600, .bg-red-700, .bg-red-800, .bg-red-900,
|
||||
.bg-green-500, .bg-green-600, .bg-green-700, .bg-green-800, .bg-green-900,
|
||||
.bg-purple-500, .bg-purple-600, .bg-purple-700, .bg-purple-800, .bg-purple-900,
|
||||
.bg-indigo-500, .bg-indigo-600, .bg-indigo-700, .bg-indigo-800, .bg-indigo-900,
|
||||
.bg-violet-500, .bg-violet-600, .bg-violet-700, .bg-violet-800, .bg-violet-900,
|
||||
.bg-sky-500, .bg-sky-600, .bg-sky-700, .bg-sky-800, .bg-sky-900,
|
||||
.bg-teal-500, .bg-teal-600, .bg-teal-700, .bg-teal-800, .bg-teal-900,
|
||||
.bg-gray-700, .bg-gray-800, .bg-gray-900,
|
||||
.bg-slate-500, .bg-slate-600, .bg-slate-700, .bg-slate-800, .bg-slate-900 {
|
||||
color: #ffffff !important;
|
||||
}
|
||||
|
||||
/* Ensure all light background buttons have dark text */
|
||||
.bg-gray-50, .bg-gray-100, .bg-gray-200, .bg-gray-300,
|
||||
.bg-white, .bg-transparent {
|
||||
color: #374151 !important; /* gray-700 */
|
||||
}
|
||||
|
||||
/* Fix button children text inheritance */
|
||||
.bg-blue-500 *, .bg-blue-600 *, .bg-blue-700 *, .bg-blue-800 *, .bg-blue-900 *,
|
||||
.bg-red-500 *, .bg-red-600 *, .bg-red-700 *, .bg-red-800 *, .bg-red-900 *,
|
||||
.bg-green-500 *, .bg-green-600 *, .bg-green-700 *, .bg-green-800 *, .bg-green-900 *,
|
||||
.bg-purple-500 *, .bg-purple-600 *, .bg-purple-700 *, .bg-purple-800 *, .bg-purple-900 *,
|
||||
.bg-violet-500 *, .bg-violet-600 *, .bg-violet-700 *, .bg-violet-800 *, .bg-violet-900 *,
|
||||
.bg-sky-500 *, .bg-sky-600 *, .bg-sky-700 *, .bg-sky-800 *, .bg-sky-900 *,
|
||||
.bg-teal-500 *, .bg-teal-600 *, .bg-teal-700 *, .bg-teal-800 *, .bg-teal-900 *,
|
||||
.bg-gray-700 *, .bg-gray-800 *, .bg-gray-900 *,
|
||||
.bg-slate-500 *, .bg-slate-600 *, .bg-slate-700 *, .bg-slate-800 *, .bg-slate-900 * {
|
||||
color: inherit !important;
|
||||
}
|
||||
|
||||
/* Global form element styling for consistency */
|
||||
select, input[type="text"], input[type="number"], input[type="search"],
|
||||
input[type="email"], input[type="password"], input[type="datetime-local"],
|
||||
textarea, .form-input, .form-select {
|
||||
background-color: rgba(255, 255, 255, 0.95) !important;
|
||||
color: #374151 !important; /* gray-700 */
|
||||
border: 1px solid rgba(0, 0, 0, 0.12) !important;
|
||||
border-radius: 0.375rem !important; /* rounded-md */
|
||||
}
|
||||
|
||||
select:focus, input:focus, textarea:focus,
|
||||
.form-input:focus, .form-select:focus {
|
||||
border-color: #3b82f6 !important; /* blue-500 */
|
||||
box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1) !important;
|
||||
outline: none !important;
|
||||
}
|
||||
|
||||
/* Fix dropdown option styling */
|
||||
select option {
|
||||
background-color: rgba(255, 255, 255, 0.98) !important;
|
||||
color: #374151 !important; /* gray-700 */
|
||||
padding: 8px !important;
|
||||
}
|
||||
|
||||
/* Fix pagination controls globally */
|
||||
.pagination-button, .pagination a, .pagination button {
|
||||
background-color: rgba(255, 255, 255, 0.9) !important;
|
||||
color: #374151 !important; /* gray-700 */
|
||||
border: 1px solid rgba(0, 0, 0, 0.08) !important;
|
||||
transition: all 0.15s ease-in-out !important;
|
||||
}
|
||||
|
||||
.pagination-button:hover, .pagination a:hover, .pagination button:hover {
|
||||
background-color: rgba(229, 231, 235, 1) !important; /* gray-200 */
|
||||
border-color: rgba(0, 0, 0, 0.12) !important;
|
||||
transform: translateY(-1px) !important;
|
||||
}
|
||||
|
||||
.pagination-button.active, .pagination a.active, .pagination button.active {
|
||||
background-color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
color: #ffffff !important;
|
||||
border-color: #2563eb !important; /* blue-600 - darker light blue */
|
||||
font-weight: 600 !important;
|
||||
}
|
||||
/* Loading spinner */
|
||||
.loading-spin {
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
@keyframes spin {
|
||||
from { transform: rotate(0deg); }
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
/* Notification */
|
||||
.notification {
|
||||
position: fixed;
|
||||
bottom: 5rem; /* Adjusted from bottom-20 */
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
padding: 0.75rem 1.25rem; /* px-5 py-3 */
|
||||
border-radius: 0.5rem; /* rounded-lg */
|
||||
background-color: rgba(34, 197, 94, 0.95); /* green-500 for success */
|
||||
color: white;
|
||||
font-weight: 500; /* font-medium */
|
||||
z-index: 1000; /* Increased z-index */
|
||||
opacity: 0;
|
||||
transition: opacity 0.3s ease-in-out, transform 0.3s ease-in-out;
|
||||
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
.notification.show {
|
||||
opacity: 1;
|
||||
transform: translate(-50%, 0);
|
||||
}
|
||||
.notification.error {
|
||||
background-color: rgba(239, 68, 68, 0.95); /* red-500 for error */
|
||||
}
|
||||
/* Scroll buttons */
|
||||
.scroll-buttons {
|
||||
position: fixed;
|
||||
right: 1.25rem; /* right-5 */
|
||||
bottom: 5rem; /* bottom-20 */
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem; /* gap-2 */
|
||||
z-index: 10;
|
||||
}
|
||||
.scroll-button {
|
||||
width: 2.5rem; /* w-10 */
|
||||
height: 2.5rem; /* h-10 */
|
||||
background-color: #3b82f6; /* blue-500 - light blue */
|
||||
color: white;
|
||||
border-radius: 9999px; /* rounded-full */
|
||||
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); /* shadow-md */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.3s ease-in-out;
|
||||
}
|
||||
.scroll-button:hover {
|
||||
background-color: #2563eb; /* blue-600 - darker light blue */
|
||||
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); /* hover:shadow-lg */
|
||||
}
|
||||
|
||||
/* Global overrides for light theme consistency */
|
||||
.text-gray-200, .text-gray-300, .text-gray-400 {
|
||||
color: #6b7280 !important; /* gray-500 for better contrast */
|
||||
}
|
||||
|
||||
/* Navigation and header improvements */
|
||||
.bg-primary-600, .bg-primary-700 {
|
||||
background-color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
}
|
||||
|
||||
.text-primary-600, .text-primary-700 {
|
||||
color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
}
|
||||
|
||||
.border-primary-500, .focus\\:border-primary-500 {
|
||||
border-color: #3b82f6 !important; /* blue-500 */
|
||||
}
|
||||
|
||||
.ring-primary-200, .focus\\:ring-primary-200 {
|
||||
--tw-ring-color: rgba(59, 130, 246, 0.2) !important; /* blue-500 with opacity */
|
||||
}
|
||||
|
||||
/* Global purple to blue conversion */
|
||||
.bg-violet-50, .bg-violet-100, .bg-violet-200, .bg-violet-300, .bg-violet-400, .bg-violet-500, .bg-violet-600, .bg-violet-700, .bg-violet-800, .bg-violet-900 {
|
||||
background-color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
}
|
||||
|
||||
.text-violet-50, .text-violet-100, .text-violet-200, .text-violet-300, .text-violet-400, .text-violet-500, .text-violet-600, .text-violet-700, .text-violet-800, .text-violet-900 {
|
||||
color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
}
|
||||
|
||||
.border-violet-50, .border-violet-100, .border-violet-200, .border-violet-300, .border-violet-400, .border-violet-500, .border-violet-600, .border-violet-700, .border-violet-800, .border-violet-900 {
|
||||
border-color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
}
|
||||
|
||||
/* Global button color overrides */
|
||||
/* Blue buttons to light blue */
|
||||
.bg-blue-500, .bg-blue-600, .bg-blue-700, .bg-blue-800, .bg-blue-900,
|
||||
button.bg-blue-500, button.bg-blue-600, button.bg-blue-700, button.bg-blue-800, button.bg-blue-900 {
|
||||
background-color: #3b82f6 !important; /* blue-500 - light blue */
|
||||
}
|
||||
|
||||
.bg-blue-500:hover, .bg-blue-600:hover, .bg-blue-700:hover, .bg-blue-800:hover, .bg-blue-900:hover,
|
||||
button.bg-blue-500:hover, button.bg-blue-600:hover, button.bg-blue-700:hover, button.bg-blue-800:hover, button.bg-blue-900:hover,
|
||||
.hover\\:bg-blue-600:hover, .hover\\:bg-blue-700:hover, .hover\\:bg-blue-800:hover {
|
||||
background-color: #2563eb !important; /* blue-600 - darker light blue */
|
||||
}
|
||||
|
||||
/* Red buttons to bright light red */
|
||||
.bg-red-500, .bg-red-600, .bg-red-700, .bg-red-800, .bg-red-900,
|
||||
button.bg-red-500, button.bg-red-600, button.bg-red-700, button.bg-red-800, button.bg-red-900 {
|
||||
background-color: #f87171 !important; /* red-400 - bright light red */
|
||||
}
|
||||
|
||||
.bg-red-500:hover, .bg-red-600:hover, .bg-red-700:hover, .bg-red-800:hover, .bg-red-900:hover,
|
||||
button.bg-red-500:hover, button.bg-red-600:hover, button.bg-red-700:hover, button.bg-red-800:hover, button.bg-red-900:hover,
|
||||
.hover\\:bg-red-600:hover, .hover\\:bg-red-700:hover, .hover\\:bg-red-800:hover {
|
||||
background-color: #ef4444 !important; /* red-500 - darker bright light red */
|
||||
}
|
||||
|
||||
/* Gray buttons to light gray */
|
||||
.bg-gray-500, .bg-gray-600, .bg-gray-700, .bg-gray-800, .bg-gray-900,
|
||||
button.bg-gray-500, button.bg-gray-600, button.bg-gray-700, button.bg-gray-800, button.bg-gray-900 {
|
||||
background-color: #e5e7eb !important; /* gray-200 - light gray */
|
||||
color: #374151 !important; /* gray-700 - dark text for contrast */
|
||||
}
|
||||
|
||||
.bg-gray-500:hover, .bg-gray-600:hover, .bg-gray-700:hover, .bg-gray-800:hover, .bg-gray-900:hover,
|
||||
button.bg-gray-500:hover, button.bg-gray-600:hover, button.bg-gray-700:hover, button.bg-gray-800:hover, button.bg-gray-900:hover,
|
||||
.hover\\:bg-gray-600:hover, .hover\\:bg-gray-700:hover, .hover\\:bg-gray-800:hover {
|
||||
background-color: #d1d5db !important; /* gray-300 - darker light gray */
|
||||
color: #374151 !important; /* gray-700 - dark text for contrast */
|
||||
}
|
||||
|
||||
/* Ensure all text has proper contrast in light theme */
|
||||
.text-white {
|
||||
color: #374151 !important; /* gray-700 for better contrast on light backgrounds */
|
||||
}
|
||||
|
||||
/* Fix dark button text - ensure white text on dark backgrounds */
|
||||
.bg-blue-500, .bg-blue-600, .bg-blue-700, .bg-blue-800, .bg-blue-900,
|
||||
.bg-red-500, .bg-red-600, .bg-red-700, .bg-red-800, .bg-red-900,
|
||||
.bg-green-500, .bg-green-600, .bg-green-700, .bg-green-800, .bg-green-900,
|
||||
.bg-purple-500, .bg-purple-600, .bg-purple-700, .bg-purple-800, .bg-purple-900,
|
||||
.bg-indigo-500, .bg-indigo-600, .bg-indigo-700, .bg-indigo-800, .bg-indigo-900,
|
||||
.bg-gray-700, .bg-gray-800, .bg-gray-900,
|
||||
.bg-sky-500, .bg-sky-600, .bg-sky-700, .bg-sky-800, .bg-sky-900 {
|
||||
color: #ffffff !important;
|
||||
}
|
||||
|
||||
/* Ensure buttons with dark backgrounds have white text */
|
||||
button.bg-blue-500, button.bg-blue-600, button.bg-blue-700,
|
||||
button.bg-red-500, button.bg-red-600, button.bg-red-700,
|
||||
button.bg-green-500, button.bg-green-600, button.bg-green-700,
|
||||
button.bg-sky-500, button.bg-sky-600, button.bg-sky-700,
|
||||
.btn-primary, .btn-danger, .btn-success, .btn-info {
|
||||
color: #ffffff !important;
|
||||
}
|
||||
|
||||
/* Override any nested text color rules for dark buttons */
|
||||
.bg-blue-500 *, .bg-blue-600 *, .bg-blue-700 *,
|
||||
.bg-red-500 *, .bg-red-600 *, .bg-red-700 *,
|
||||
.bg-green-500 *, .bg-green-600 *, .bg-green-700 *,
|
||||
.bg-sky-500 *, .bg-sky-600 *, .bg-sky-700 * {
|
||||
color: inherit !important;
|
||||
}
|
||||
|
||||
{% block head_extra_styles %}
|
||||
{% endblock %}
|
||||
</style>
|
||||
{% block head_extra_scripts %}{% endblock %}
|
||||
</head>
|
||||
<body class="bg-gradient min-h-screen text-gray-800 pt-6 pb-16">
|
||||
|
||||
</head>
|
||||
<body class="bg-white min-h-screen text-gray-900 pt-6 pb-16">
|
||||
{% block content %}{% endblock %}
|
||||
|
||||
<!-- 底部版权 -->
|
||||
<div class="fixed bottom-0 left-0 w-full py-3 bg-white bg-opacity-80 backdrop-blur-md text-center text-sm text-gray-600 border-t border-gray-200">
|
||||
© <span id="copyright-year"></span> by
|
||||
<a href="https://linux.do/u/snaily" target="_blank" class="text-primary-600 hover:text-primary-800 transition duration-300">
|
||||
<img src="https://linux.do/user_avatar/linux.do/snaily/288/306510_2.gif" alt="snaily" class="inline-block w-5 h-5 rounded-full align-middle mr-1">snaily
|
||||
</a> |
|
||||
<a href="https://github.com/snailyp/gemini-balance" target="_blank" class="text-primary-600 hover:text-primary-800 transition duration-300">
|
||||
<i class="fab fa-github"></i> GitHub
|
||||
<div
|
||||
class="fixed bottom-0 left-0 w-full py-3 bg-white bg-opacity-95 backdrop-blur-md text-sm text-gray-800 border-t border-gray-200 flex flex-col items-center space-y-1"
|
||||
>
|
||||
<!-- 第一行 -->
|
||||
<div class="flex items-center justify-center space-x-2">
|
||||
<span>© <span id="copyright-year"></span> by</span>
|
||||
<a
|
||||
href="https://linux.do/u/snaily"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<img
|
||||
src="https://linux.do/user_avatar/linux.do/snaily/288/306510_2.gif"
|
||||
alt="snaily"
|
||||
class="inline-block w-5 h-5 rounded-full align-middle mr-1"
|
||||
/>snaily
|
||||
</a>
|
||||
{% if request and request.app.state.update_info %}
|
||||
{% set update_info = request.app.state.update_info %}
|
||||
<span class="mx-1">|</span>
|
||||
<span class="text-xs text-gray-500">v{{ update_info.current_version }}</span>
|
||||
{% if update_info.update_available %}
|
||||
<span class="mx-1">|</span>
|
||||
<a href="https://github.com/snailyp/gemini-balance/releases/latest" target="_blank" class="text-yellow-600 hover:text-yellow-800 transition duration-300 animate-pulse">
|
||||
<i class="fas fa-arrow-up"></i> 新版本: v{{ update_info.latest_version }}
|
||||
</a>
|
||||
{% elif update_info.error_message and update_info.error_message != 'Checking...' %}
|
||||
<span class="mx-1">|</span>
|
||||
<span class="text-xs text-red-500" title="{{ update_info.error_message }}">更新检查失败</span>
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
<span class="text-gray-400">|</span>
|
||||
<a
|
||||
href="https://github.com/snailyp/gemini-balance"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fab fa-github mr-1"></i> GitHub
|
||||
</a>
|
||||
</div>
|
||||
<!-- 第二行 -->
|
||||
<div class="flex items-center justify-center space-x-2 text-xs">
|
||||
<a
|
||||
href="https://gb-docs.snaily.top/guide/supportme.html"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fas fa-drumstick-bite text-yellow-600 mr-1"></i> 给作者加鸡腿
|
||||
</a>
|
||||
<span class="text-gray-400">|</span>
|
||||
<a
|
||||
href="https://gb-docs.snaily.top"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fas fa-book mr-1"></i> 在线文档
|
||||
</a>
|
||||
<span class="text-gray-400">|</span>
|
||||
<a
|
||||
href="https://t.me/+soaHax5lyI0wZDVl"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fab fa-telegram-plane mr-1"></i> 交流群
|
||||
</a>
|
||||
<span class="text-gray-400">|</span>
|
||||
<span class="text-yellow-600 font-semibold flex items-center">
|
||||
<i class="fas fa-exclamation-triangle mr-1"></i>免费项目,谨防诈骗
|
||||
</span>
|
||||
<span id="version-info-container" class="inline-flex items-center">
|
||||
<!-- Version info will be loaded here by JavaScript -->
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 通用JS -->
|
||||
<script>
|
||||
// 设置版权年份
|
||||
document.getElementById('copyright-year').textContent = new Date().getFullYear();
|
||||
// 设置版权年份
|
||||
document.getElementById("copyright-year").textContent =
|
||||
new Date().getFullYear();
|
||||
|
||||
// 滚动到顶部/底部函数 (如果页面需要)
|
||||
function scrollToTop() {
|
||||
window.scrollTo({ top: 0, behavior: 'smooth' });
|
||||
}
|
||||
function scrollToBottom() {
|
||||
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
|
||||
// 滚动到顶部/底部函数 (如果页面需要)
|
||||
function scrollToTop() {
|
||||
window.scrollTo({ top: 0, behavior: "smooth" });
|
||||
}
|
||||
function scrollToBottom() {
|
||||
window.scrollTo({
|
||||
top: document.body.scrollHeight,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}
|
||||
|
||||
// 显示通知
|
||||
function showNotification(message, type = "success", duration = 3000) {
|
||||
const notification =
|
||||
document.getElementById("notification") ||
|
||||
createNotificationElement();
|
||||
if (!notification) return;
|
||||
|
||||
notification.textContent = message;
|
||||
notification.className = "notification show"; // Reset classes
|
||||
if (type === "error") {
|
||||
notification.classList.add("error");
|
||||
}
|
||||
|
||||
// 显示通知
|
||||
function showNotification(message, type = 'success', duration = 3000) {
|
||||
const notification = document.getElementById('notification') || createNotificationElement();
|
||||
if (!notification) return;
|
||||
|
||||
notification.textContent = message;
|
||||
notification.className = 'notification show'; // Reset classes
|
||||
if (type === 'error') {
|
||||
notification.classList.add('error');
|
||||
}
|
||||
|
||||
// Clear previous timeout if exists
|
||||
if (notification.timeoutId) {
|
||||
clearTimeout(notification.timeoutId);
|
||||
}
|
||||
|
||||
notification.timeoutId = setTimeout(() => {
|
||||
notification.classList.remove('show');
|
||||
// Optional: remove the element after fade out if dynamically created
|
||||
// setTimeout(() => notification.remove(), 300);
|
||||
}, duration);
|
||||
// Clear previous timeout if exists
|
||||
if (notification.timeoutId) {
|
||||
clearTimeout(notification.timeoutId);
|
||||
}
|
||||
|
||||
// Helper to create notification element if it doesn't exist
|
||||
function createNotificationElement() {
|
||||
let notification = document.getElementById('notification');
|
||||
if (!notification) {
|
||||
notification = document.createElement('div');
|
||||
notification.id = 'notification';
|
||||
notification.className = 'notification';
|
||||
document.body.appendChild(notification);
|
||||
}
|
||||
return notification;
|
||||
}
|
||||
notification.timeoutId = setTimeout(() => {
|
||||
notification.classList.remove("show");
|
||||
// Optional: remove the element after fade out if dynamically created
|
||||
// setTimeout(() => notification.remove(), 300);
|
||||
}, duration);
|
||||
}
|
||||
|
||||
// 页面刷新带加载状态
|
||||
function refreshPage(button) {
|
||||
if (button) {
|
||||
const icon = button.querySelector('i');
|
||||
if (icon) {
|
||||
icon.classList.add('loading-spin');
|
||||
}
|
||||
}
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 300); // Short delay to show spinner
|
||||
// Helper to create notification element if it doesn't exist
|
||||
function createNotificationElement() {
|
||||
let notification = document.getElementById("notification");
|
||||
if (!notification) {
|
||||
notification = document.createElement("div");
|
||||
notification.id = "notification";
|
||||
notification.className = "notification";
|
||||
document.body.appendChild(notification);
|
||||
}
|
||||
return notification;
|
||||
}
|
||||
|
||||
// 页面刷新带加载状态
|
||||
function refreshPage(button) {
|
||||
if (button) {
|
||||
const icon = button.querySelector("i");
|
||||
if (icon) {
|
||||
icon.classList.add("loading-spin");
|
||||
}
|
||||
}
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 300); // Short delay to show spinner
|
||||
}
|
||||
|
||||
// --- Version Check ---
|
||||
const versionInfoContainer = document.getElementById(
|
||||
"version-info-container"
|
||||
);
|
||||
|
||||
async function fetchVersionInfo() {
|
||||
if (!versionInfoContainer) return;
|
||||
versionInfoContainer.innerHTML =
|
||||
'<span class="mx-1">|</span><span class="text-xs text-gray-700">检查更新中...</span>'; // Initial loading state
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/version/check");
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const data = await response.json();
|
||||
|
||||
let versionHtml = `<span class="mx-1">|</span><span class="text-xs text-gray-800">v${data.current_version}</span>`;
|
||||
if (data.update_available) {
|
||||
versionHtml += `
|
||||
<span class="mx-1">|</span>
|
||||
<a href="https://github.com/snailyp/gemini-balance/releases/latest" target="_blank" class="text-yellow-600 hover:text-yellow-800 transition duration-300 animate-pulse">
|
||||
<i class="fas fa-arrow-up"></i> 新版本: v${data.latest_version}
|
||||
</a>`;
|
||||
} else if (data.error_message) {
|
||||
versionHtml += `
|
||||
<span class="mx-1">|</span>
|
||||
<span class="text-xs text-red-500" title="${data.error_message}">更新检查失败</span>`;
|
||||
} else {
|
||||
versionHtml += `<span class="mx-1">|</span><span class="text-xs text-green-500">已是最新</span>`; // Indicate up-to-date
|
||||
}
|
||||
versionInfoContainer.innerHTML = versionHtml;
|
||||
} catch (error) {
|
||||
console.error("Error fetching version info:", error);
|
||||
versionInfoContainer.innerHTML = `<span class="mx-1">|</span><span class="text-xs text-red-500" title="无法连接到服务器或解析响应">更新检查失败</span>`;
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch immediately on load
|
||||
fetchVersionInfo();
|
||||
|
||||
// Fetch periodically (e.g., every hour)
|
||||
setInterval(fetchVersionInfo, 3600000); // 3600000 ms = 1 hour
|
||||
</script>
|
||||
{% block body_scripts %}{% endblock %}
|
||||
</body>
|
||||
</html>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -6,9 +6,16 @@ import re
|
||||
import base64
|
||||
import requests
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
|
||||
|
||||
helper_logger = logging.getLogger("app.utils")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
VERSION_FILE_PATH = PROJECT_ROOT / "VERSION"
|
||||
|
||||
|
||||
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
@@ -146,3 +153,20 @@ def is_valid_api_key(key: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def get_current_version(default_version: str = "0.0.0") -> str:
|
||||
"""Reads the current version from the VERSION file."""
|
||||
version_file = VERSION_FILE_PATH
|
||||
try:
|
||||
with version_file.open('r', encoding='utf-8') as f:
|
||||
version = f.read().strip()
|
||||
if not version:
|
||||
helper_logger.warning(f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
return version
|
||||
except FileNotFoundError:
|
||||
helper_logger.warning(f"VERSION file not found at '{version_file}'. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
except IOError as e:
|
||||
helper_logger.error(f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
|
||||
@@ -261,18 +261,20 @@ class PicGoUploader(ImageUploader):
|
||||
|
||||
class CloudFlareImgBedUploader(ImageUploader):
|
||||
"""CloudFlare图床上传器"""
|
||||
|
||||
def __init__(self, auth_code: str, api_url: str):
|
||||
|
||||
def __init__(self, auth_code: str, api_url: str, upload_folder: str = ""):
|
||||
"""
|
||||
初始化CloudFlare图床上传器
|
||||
|
||||
Args:
|
||||
auth_code: 认证码
|
||||
api_url: 上传API地址
|
||||
upload_folder: 上传文件夹路径(可选)
|
||||
"""
|
||||
self.auth_code = auth_code
|
||||
self.api_url = api_url
|
||||
|
||||
self.upload_folder = upload_folder
|
||||
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
"""
|
||||
上传图片到CloudFlare图床
|
||||
@@ -288,12 +290,16 @@ class CloudFlareImgBedUploader(ImageUploader):
|
||||
UploadError: 上传失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
# 准备请求URL(添加认证码参数,如果存在)
|
||||
# 准备请求URL参数
|
||||
params = []
|
||||
if self.upload_folder:
|
||||
params.append(f"uploadFolder={self.upload_folder}")
|
||||
if self.auth_code:
|
||||
request_url = f"{self.api_url}?authCode={self.auth_code}&uploadNameType=origin"
|
||||
else:
|
||||
request_url = f"{self.api_url}?uploadNameType=origin"
|
||||
|
||||
params.append(f"authCode={self.auth_code}")
|
||||
params.append("uploadNameType=origin")
|
||||
|
||||
request_url = f"{self.api_url}?{'&'.join(params)}"
|
||||
|
||||
# 准备文件数据
|
||||
files = {
|
||||
"file": (filename, file)
|
||||
@@ -388,6 +394,7 @@ class ImageUploaderFactory:
|
||||
elif provider == "cloudflare_imgbed":
|
||||
return CloudFlareImgBedUploader(
|
||||
credentials["auth_code"],
|
||||
credentials["base_url"]
|
||||
credentials["base_url"],
|
||||
credentials.get("upload_folder", ""),
|
||||
)
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
71
files/dataocean.svg
Normal file
71
files/dataocean.svg
Normal file
@@ -0,0 +1,71 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 19.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 603 103" style="enable-background:new 0 0 603 103;" xml:space="preserve">
|
||||
<style type="text/css">
|
||||
.st0{fill:#0080FF;}
|
||||
.st1{fill-rule:evenodd;clip-rule:evenodd;fill:#0080FF;}
|
||||
</style>
|
||||
<g id="XMLID_2369_">
|
||||
<g id="XMLID_2638_">
|
||||
<g id="XMLID_2639_">
|
||||
<g>
|
||||
<g id="XMLID_44_">
|
||||
<g id="XMLID_48_">
|
||||
<path id="XMLID_49_" class="st0" d="M52.1,102.1l0-19.6c20.8,0,36.8-20.6,28.9-42.4C78,32,71.6,25.5,63.5,22.6
|
||||
c-21.8-7.9-42.4,8.1-42.4,28.9c0,0,0,0,0,0l-19.6,0c0-33.1,32-58.9,66.7-48.1c15.2,4.7,27.2,16.8,31.9,31.9
|
||||
C110.9,70.1,85.2,102.1,52.1,102.1z"/>
|
||||
</g>
|
||||
<polygon id="XMLID_47_" class="st1" points="52.1,82.5 32.6,82.5 32.6,63 32.6,63 52.1,63 52.1,63 "/>
|
||||
<polygon id="XMLID_46_" class="st1" points="32.6,97.5 17.6,97.5 17.6,97.5 17.6,82.5 32.6,82.5 32.6,97.5 "/>
|
||||
<polygon id="XMLID_45_" class="st1" points="17.6,82.5 5,82.5 5,82.5 5,70 5,70 17.6,70 17.6,70 "/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
<g id="XMLID_2370_">
|
||||
<path id="XMLID_2635_" class="st0" d="M181.5,30.2c-5.8-4-13-6.1-21.4-6.1h-18.3v58.1h18.3c8.4,0,15.6-2.1,21.4-6.4
|
||||
c3.2-2.2,5.7-5.4,7.4-9.3c1.7-3.9,2.6-8.5,2.6-13.7c0-5.1-0.9-9.7-2.6-13.6C187.2,35.4,184.7,32.3,181.5,30.2z M152.5,34h5.8
|
||||
c6.4,0,11.7,1.3,15.7,3.7c4.4,2.7,6.7,7.8,6.7,15.1c0,7.6-2.3,12.9-6.7,15.8h0c-3.8,2.5-9.1,3.8-15.6,3.8h-5.8V34z"/>
|
||||
<path id="XMLID_2634_" class="st0" d="M204.3,23.4c-1.8,0-3.3,0.6-4.5,1.8c-1.2,1.2-1.9,2.7-1.9,4.4c0,1.8,0.6,3.3,1.9,4.5
|
||||
c1.2,1.2,2.7,1.9,4.5,1.9c1.8,0,3.3-0.6,4.5-1.9c1.2-1.2,1.9-2.8,1.9-4.5c0-1.8-0.6-3.3-1.9-4.4C207.6,24,206,23.4,204.3,23.4z"/>
|
||||
<rect id="XMLID_2564_" x="199" y="41.3" class="st0" width="10.3" height="41"/>
|
||||
<path id="XMLID_2561_" class="st0" d="M246.8,44.7c-3.1-2.8-6.6-4.4-10.3-4.4c-5.7,0-10.4,2-14.1,5.8c-3.7,3.8-5.5,8.8-5.5,14.7
|
||||
c0,5.8,1.8,10.7,5.5,14.7c3.7,3.8,8.4,5.8,14.1,5.8c4,0,7.4-1.1,10.2-3.3V79c0,3.4-0.9,6-2.7,7.9c-1.8,1.8-4.3,2.7-7.4,2.7
|
||||
c-4.8,0-7.7-1.9-11.4-6.8l-7,6.7l0.2,0.3c1.5,2.1,3.8,4.2,6.9,6.2c3.1,2,6.9,3,11.5,3c6.1,0,11.1-1.9,14.7-5.6
|
||||
c3.7-3.7,5.5-8.7,5.5-14.9V41.3h-10.1V44.7z M244.1,68.9c-1.8,2-4.1,3-7.1,3c-3,0-5.3-1-7-3c-1.8-2-2.7-4.7-2.7-8
|
||||
c0-3.3,0.9-6.1,2.7-8.1c1.8-2,4.1-3.1,7-3.1c3,0,5.3,1,7.1,3.1c1.8,2,2.7,4.8,2.7,8.1C246.8,64.2,245.8,66.9,244.1,68.9z"/>
|
||||
<rect id="XMLID_2560_" x="265.7" y="41.3" class="st0" width="10.3" height="41"/>
|
||||
<path id="XMLID_2552_" class="st0" d="M271,23.4c-1.8,0-3.3,0.6-4.5,1.8c-1.2,1.2-1.9,2.7-1.9,4.4c0,1.8,0.6,3.3,1.9,4.5
|
||||
c1.2,1.2,2.7,1.9,4.5,1.9c1.8,0,3.3-0.6,4.5-1.9c1.2-1.2,1.9-2.8,1.9-4.5c0-1.8-0.6-3.3-1.9-4.4C274.3,24,272.7,23.4,271,23.4z"/>
|
||||
<path id="XMLID_2509_" class="st0" d="M298.6,30.3h-10.1v11.1h-5.9v9.4h5.9v17c0,5.3,1.1,9.1,3.2,11.3c2.1,2.2,5.8,3.3,11.1,3.3
|
||||
c1.7,0,3.4-0.1,5-0.2l0.5,0v-9.4l-3.5,0.2c-2.5,0-4.1-0.4-4.9-1.3c-0.8-0.9-1.2-2.7-1.2-5.4V50.7h9.6v-9.4h-9.6V30.3z"/>
|
||||
<rect id="XMLID_2508_" x="356.5" y="24.1" class="st0" width="10.3" height="58.1"/>
|
||||
<path id="XMLID_2470_" class="st0" d="M470.9,67.6c-1.8,2.1-3.7,3.9-5.2,4.8v0c-1.4,0.9-3.2,1.4-5.3,1.4c-3,0-5.5-1.1-7.5-3.4
|
||||
c-2-2.3-3-5.2-3-8.7s1-6.4,2.9-8.6c2-2.3,4.4-3.4,7.4-3.4c3.3,0,6.8,2.1,9.8,5.6l6.8-6.5l0,0c-4.4-5.8-10.1-8.5-16.9-8.5
|
||||
c-5.7,0-10.6,2.1-14.6,6.1c-4,4-6,9.2-6,15.3s2,11.2,6,15.3c4,4.1,8.9,6.1,14.6,6.1c7.5,0,13.5-3.2,17.5-9.1L470.9,67.6z"/>
|
||||
<path id="XMLID_2460_" class="st0" d="M513.2,47c-1.5-2-3.5-3.7-5.9-4.9c-2.5-1.2-5.3-1.8-8.5-1.8c-5.8,0-10.5,2.1-14,6.3
|
||||
c-3.4,4.2-5.2,9.3-5.2,15.4c0,6.2,1.9,11.3,5.7,15.3c3.7,3.9,8.8,5.9,14.9,5.9c6.9,0,12.7-2.8,16.9-8.4l0.2-0.3l-6.7-6.5l0,0
|
||||
c-0.6,0.8-1.5,1.6-2.3,2.4c-1,1-2,1.7-3,2.2c-1.5,0.8-3.3,1.1-5.2,1.1c-2.9,0-5.2-0.8-7-2.5c-1.7-1.5-2.7-3.6-2.9-6.2h27.3
|
||||
l0.1-3.8c0-2.7-0.4-5.2-1.1-7.6C515.8,51.3,514.7,49.1,513.2,47z M490.7,56.7c0.5-2,1.4-3.6,2.7-4.9c1.4-1.4,3.2-2.1,5.4-2.1
|
||||
c2.5,0,4.4,0.7,5.7,2.1c1.2,1.3,1.9,2.9,2.1,4.8H490.7z"/>
|
||||
<path id="XMLID_2456_" class="st0" d="M552.8,44.4L552.8,44.4c-3.1-2.7-7.4-4-12.8-4c-3.4,0-6.6,0.8-9.5,2.2
|
||||
c-2.7,1.4-5.3,3.6-7,6.6l0.1,0.1l6.6,6.3c2.7-4.3,5.7-5.8,9.7-5.8c2.2,0,3.9,0.6,5.3,1.7c1.4,1.1,2,2.6,2,4.4v2
|
||||
c-2.6-0.8-5.1-1.2-7.6-1.2c-5.1,0-9.3,1.2-12.4,3.6c-3.1,2.4-4.7,5.9-4.7,10.2c0,3.8,1.3,7,4,9.3c2.7,2.2,6,3.4,9.9,3.4
|
||||
c3.9,0,7.6-1.6,10.9-4.3v3.4h10.1V55.9C557.6,51,556,47.1,552.8,44.4z M534.5,66.6c1.2-0.8,2.8-1.2,4.9-1.2c2.5,0,5.1,0.5,7.8,1.5
|
||||
v4C545,73,542,74,538.3,74c-1.8,0-3.2-0.4-4.1-1.2c-0.9-0.8-1.4-1.7-1.4-3C532.8,68.5,533.4,67.4,534.5,66.6z"/>
|
||||
<path id="XMLID_2454_" class="st0" d="M597.2,45.2c-2.9-3.2-6.9-4.8-12-4.8c-4.1,0-7.4,1.2-9.9,3.5v-2.5h-10.1v41h10.3V59.7
|
||||
c0-3.1,0.7-5.6,2.2-7.3c1.5-1.8,3.4-2.6,6.1-2.6c2.3,0,4.1,0.8,5.4,2.3c1.3,1.6,2,3.7,2,6.4v23.7h10.3V58.5
|
||||
C601.5,52.9,600.1,48.4,597.2,45.2z"/>
|
||||
<path id="XMLID_2450_" class="st0" d="M343.6,44.4L343.6,44.4c-3.1-2.7-7.4-4-12.8-4c-3.4,0-6.6,0.8-9.5,2.2
|
||||
c-2.7,1.4-5.3,3.6-7,6.6l0.1,0.1l6.6,6.3c2.7-4.3,5.7-5.8,9.7-5.8c2.2,0,3.9,0.6,5.3,1.7c1.4,1.1,2,2.6,2,4.4v2
|
||||
c-2.6-0.8-5.1-1.2-7.6-1.2c-5.1,0-9.3,1.2-12.4,3.6c-3.1,2.4-4.7,5.9-4.7,10.2c0,3.8,1.3,7,4,9.3c2.7,2.2,6,3.4,9.9,3.4
|
||||
c3.9,0,7.6-1.6,10.9-4.3v3.4h10.1V55.9C348.3,51,346.7,47.1,343.6,44.4z M325.3,66.6c1.2-0.8,2.8-1.2,4.9-1.2
|
||||
c2.5,0,5.1,0.5,7.8,1.5v4c-2.2,2.1-5.2,3.1-8.9,3.1c-1.8,0-3.2-0.4-4.1-1.2c-0.9-0.8-1.4-1.7-1.4-3
|
||||
C323.6,68.5,324.1,67.4,325.3,66.6z"/>
|
||||
<path id="XMLID_2371_" class="st0" d="M404.2,83.1c-16.5,0-30-13.4-30-30s13.4-30,30-30c16.5,0,30,13.4,30,30
|
||||
S420.7,83.1,404.2,83.1z M404.2,33.8c-10.7,0-19.4,8.7-19.4,19.4s8.7,19.4,19.4,19.4c10.7,0,19.4-8.7,19.4-19.4
|
||||
S414.9,33.8,404.2,33.8z"/>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 5.7 KiB |
@@ -1,5 +1,5 @@
|
||||
fastapi
|
||||
httpx
|
||||
httpx[socks]
|
||||
openai
|
||||
pydantic
|
||||
pydantic_settings
|
||||
@@ -9,13 +9,12 @@ uvicorn
|
||||
google-genai
|
||||
jinja2
|
||||
python-multipart
|
||||
cryptography # 支持 MySQL 8+ caching_sha2_password 验证
|
||||
# 数据库相关依赖
|
||||
cryptography
|
||||
pymysql
|
||||
sqlalchemy
|
||||
aiomysql
|
||||
aiosqlite
|
||||
databases
|
||||
python-dotenv
|
||||
apscheduler # 添加定时任务库
|
||||
|
||||
apscheduler
|
||||
packaging
|
||||
|
||||
Reference in New Issue
Block a user