Compare commits

..

35 Commits

Author SHA1 Message Date
snaily
150824938c chore(version): 更新版本号至 2.2.8 2025-09-23 22:10:21 +08:00
snaily
ccaea40281 chore(version): 更新版本号至 2.2.7.1 2025-09-23 21:27:21 +08:00
snaily
9d8e77c9f7 fix: 修复全量检测key失效问题 2025-09-19 14:20:06 +08:00
snaily
19941f7f50 feat(logging): 增加日志记录以显示使用的授权令牌 2025-09-18 10:59:38 +08:00
snaily
d6981c204a build(docker): 简化 Dockerfile,移除多阶段构建 2025-09-18 10:49:03 +08:00
snaily
d386cc7180 build(docker): 优化 Dockerfile 以实现多阶段构建
将 Dockerfile 修改为多阶段构建,以减小最终镜像的体积并提高构建效率。

第一阶段(builder)负责安装 Python 依赖项。
第二阶段创建最终的生产镜像,仅从构建器阶段复制已安装的依赖包和应用程序代码,不包含构建时的工具和缓存。

这种方法可以显著减小镜像大小,并利用 Docker 的层缓存机制,仅在 `requirements.txt` 发生变化时才重新安装依赖。
2025-09-18 10:21:28 +08:00
snaily
bed3647424 chore(version): 更新版本号至 2.2.7 2025-09-18 10:00:48 +08:00
snaily
95b5acad66 refactor(api): 优化错误处理和日志记录
对多个模块进行了重构,以改进错误处理和日志记录机制。

主要变更包括:
- 在 `gemini_routes` 中,现在会返回更具体的错误信息,包括错误码和错误消息,而不仅仅是异常的字符串表示。
- 在 `api_client` 中,简化了 Gemini API 客户端的错误处理逻辑,移除了冗余的 `try...except` 块,让异常直接向上抛出。
- 在多个服务(如 `openai_chat_service`, `embedding_service`, `tts_service` 等)中,增加了根据配置项 `ERROR_LOG_RECORD_REQUEST_BODY` 来决定是否记录请求体的逻辑,以增强隐私和性能控制。
- 在前端 `keys_status.js` 中,更新了密钥验证结果的处理逻辑,以适应后端返回的新的错误对象结构(包含 `error_code` 和 `error_message`),并移除了冗余的 `executeVerifyAllKeys` 函数。
2025-09-18 09:59:32 +08:00
snaily
68b65814bc chore(version): 更新版本号至 2.2.6 2025-09-18 07:37:37 +08:00
snaily
88f5b33018 docs(config): 优化错误日志配置选项的说明文案
将错误日志记录请求体选项的提示文案从"关闭可避免敏感数据入库"
更新为"关闭可减少大量磁盘空间占用",更准确地描述该功能的作用
2025-09-18 06:41:09 +08:00
snaily
8c62c8121d feat(static): 实现静态资源版本化和模板全局变量支持
- 在Dockerfile中添加默认环境变量配置
- 新增静态资源URL版本化管理功能
- 更新所有模板文件使用static_url函数替代硬编码路径
- 优化错误日志页面移动端按钮布局和响应式设计
- 简化异常处理器返回格式

BREAKING CHANGE: 静态资源URL格式变更,需要重新部署以确保资源正确加载
2025-09-18 06:29:45 +08:00
snaily
05762cb6a5 feat(config):更新默认模型和相关配置
更新默认模型和相关配置:
- 将默认测试模型从 gemini-1.5-flash 更新为 gemini-2.5-flash-lite
- 更新思考模型列表至 gemini-2.5-flash 和 gemini-2.5-pro
- 添加新的图像模型 gemini-2.5-flash-image-preview
- 更新搜索模型配置以支持最新的 Gemini 2.5 系列
- 同步更新文档中的模型配置说明
2025-09-18 05:24:29 +08:00
snaily
78f38cc981 refactor(scheduler): 优化定时任务配置和时间处理
- 支持CHECK_INTERVAL_HOURS设置为0以禁用密钥检查任务
- 调整日志清理任务执行时间从凌晨3点改为0点
- 移除timezone依赖,使用本地时间处理
- 优化代码格式和导入顺序
- 为配置编辑器添加CHECK_INTERVAL_HOURS输入验证
- 改进UI布局,为关键配置项添加警告提示
2025-09-18 05:14:43 +08:00
snaily
79f47c315e style(ui): 重构配置编辑器字段描述显示方式
将所有配置字段的描述文本从底部小字说明改为标签旁的问号图标提示,提升界面简洁度和用户体验。同时优化了数组容器和独立输入框的边框样式区分。
2025-09-18 04:49:04 +08:00
snaily
708fb1604b feat(config): 新增错误日志请求体记录开关(默认关闭)
- 新增环境变量 ERROR_LOG_RECORD_REQUEST_BODY,默认 false
- Settings 增加该配置,并在各服务写入错误日志时按开关决定是否
  入库请求体,降低敏感信息泄露风险
- 配置编辑页新增对应开关,前端初始化默认值;.env.example、
  README/README_ZH 同步更新
- db: add_error_log 支持 None 请求体并更稳健解析字符串/字典
- perf(db): 将错误日志批量删除 batch_size 从 500 下调到 200,
  兼容 SQLite/MySQL 参数上限并提升稳定性
- docs: 补充 aliyun_oss 上传提供商与 OSS 配置示例
- style: 轻微代码格式化与导入顺序优化
2025-09-18 04:21:28 +08:00
snaily
7dbd3ad693 perf(db): 优化错误日志删除以支持大数据量
将 `delete_all_error_logs` 函数的实现从一次性删除所有记录改为分批删除。这可以防止在处理大量日志时因数据库事务过长而导致的超时或性能问题。

- 每次从数据库中获取一批日志ID,然后根据ID进行删除。
- 在每个批次处理后,使用 `asyncio.sleep(0)` 将控制权交还给事件循环,避免长时间阻塞。
- 批次大小设置为500,以兼容不同数据库(如SQLite)对SQL参数数量的限制。
- 函数现在返回实际删除的日志总数,而不是一个固定的成功指示符。
2025-09-18 03:33:59 +08:00
snaily
67dd1af583 refactor(error): 统一异常处理和响应格式
这次提交重构了整个应用的异常处理机制,保证了处理方式的一致性,还能提供更详细的错误信息。

主要改动包括:
- 修改了 `ApiClient`,现在抛出的异常会同时包含状态码和消息。这样上游服务就能传递准确的 HTTP 错误响应啦。
- 更新了所有服务层(`gemini`、`openai`、`vertex`、`embedding`),现在会捕获这些结构化的异常,不再从字符串里解析错误消息了。
- 增强了路由级别的错误处理,特别是针对流式端点,能正确捕获初始化错误,并返回结构化的 JSON 错误响应,而不是格式错误的 SSE 事件。
- 在所有 API 路由中添加了 `allowed_token` 的日志记录,方便追踪和调试授权问题。
- 还有一些常规的代码清理,比如调整了 import 顺序和格式化代码,提高了可读性和可维护性。
2025-09-18 03:11:45 +08:00
snaily
e104a50cf4 Merge pull request #347 from bbbugg:Add-final-SSE-error
Fix: Gemini streaming returns a structured error instead of empty responses
2025-09-17 23:58:53 +08:00
snaily
6b9647813b Merge pull request #360 from minguncle:feat-support-aliyunoss
Feat support aliyunoss
2025-09-17 20:43:27 +08:00
wanglinjie
f863e3065b Merge remote-tracking branch 'origin/main' 2025-09-03 09:38:15 +08:00
wanglinjie
1314e0ee09 feat(upload): add support for Aliyhun OSS 2025-09-03 09:38:01 +08:00
snaily
81d92370ad Merge pull request #351 from SquirrelJimmy/main 2025-09-03 03:27:45 +08:00
SquirrelJimmy
5f6eba62cc feat: 增加配置页面的picgo 自定义url, 处理自定义picgo的返回结果 2025-09-01 11:52:12 +08:00
snaily
a8a265c2a7 chore(docker): 注释掉 adminer 服务
暂时移除 Adminer 服务,因为它目前不是必需的,并且可以减少运行的容器数量,简化本地开发环境。
2025-08-31 22:00:36 +08:00
snaily
ee21e50305 Merge pull request #303 from vickyyd:main
Add adminer for convenient mysql database management
2025-08-31 21:58:34 +08:00
snaily
611559d298 feat(image): 支持多模态模型输入base64格式图片
- 在消息转换中,增加对 `data:image/png;base64,...` 格式图片的支持,允许用户直接在输入中提供base64编码的图片。
- 调整图片处理逻辑,使其能够根据模型名称判断是否启用多模态能力,避免非多模态模型错误处理图片链接。
- 当未配置图床时,模型输出的图片将回退为base64格式,确保图片内容始终可用。
- 优化了相关函数的参数传递和代码格式,提高了代码的可读性和健壮性。
2025-08-31 21:39:12 +08:00
snaily
b0127e6fc2 Merge pull request #344 from bbbugg/base64-fallback 2025-08-31 05:37:19 +08:00
bbbugg
1d15a21ce5 Remove upload folder check for Cloudflare imgbed 2025-08-31 00:33:33 +08:00
snaily
c206aa8e4a Merge pull request #316 from ConstasJ/modify-readme
docs: 在README里添加关于端点的说明
2025-08-31 00:24:53 +08:00
SquirrelJimmy
3f040b7075 feat: 增加自定义picgo api url 2025-08-29 12:49:12 +08:00
Copilot
1771555fe9 Add final SSE error events for streaming endpoints when retries are exhausted
* Initial plan

* Add final SSE error events for all streaming services

Co-authored-by: bbbugg <80089841+bbbugg@users.noreply.github.com>

* revert openai

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: bbbugg <80089841+bbbugg@users.noreply.github.com>
Co-authored-by: bbbugg <daming20120101@163.com>

Enhance error handling by extracting nested JSON from error messages in SSE events

Enhance error handling for content generation and streaming endpoints

Enhance error handling for content generation and streaming endpoints

Enhance error handling for content generation and streaming endpoints

Enhance error handling for content generation and streaming endpoints

还原vertex和openai的更改,只保留gemini
2025-08-28 22:10:32 +08:00
Copilot
8711088ebc Fix circular import issue between config, logger, and helpers modules (#2)
* Initial plan

* Fix circular import by removing top-level settings import from helpers.py

Co-authored-by: bbbugg <80089841+bbbugg@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: bbbugg <80089841+bbbugg@users.noreply.github.com>
2025-08-28 16:56:04 +08:00
Copilot
bb6c629aef Implement base64 fallback for image handling when no uploader is configured (#1)
* Initial plan

* Implement base64 fallback for image handling when no uploader configured

Co-authored-by: bbbugg <80089841+bbbugg@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: bbbugg <80089841+bbbugg@users.noreply.github.com>
2025-08-28 16:56:04 +08:00
ConstasJ
d06e418a61 docs: 修改README.md,添加关于端点的说明 2025-08-18 15:42:46 +08:00
kikii16
fd39c2c9cb Add adminer for convenient mysql database management 2025-08-13 02:21:23 +08:00
40 changed files with 2009 additions and 912 deletions

View File

@@ -14,11 +14,11 @@ AUTH_TOKEN=sk-123456
VERTEX_API_KEYS=["AQ.Abxxxxxxxxxxxxxxxxxxx"]
# For Vertex AI Platform Express API Base URL
VERTEX_EXPRESS_BASE_URL=https://aiplatform.googleapis.com/v1beta1/publishers/google
TEST_MODEL=gemini-1.5-flash
THINKING_MODELS=["gemini-2.5-flash-preview-04-17"]
THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
IMAGE_MODELS=["gemini-2.0-flash-exp"]
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
TEST_MODEL=gemini-2.5-flash-lite
THINKING_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
THINKING_BUDGET_MAP={"gemini-2.5-flash": -1}
IMAGE_MODELS=["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
SEARCH_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
# 是否启用网址上下文,默认启用
URL_CONTEXT_ENABLED=false
@@ -44,9 +44,17 @@ CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx
PICGO_API_URL=https://www.picgo.net/api/1/upload
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
CLOUDFLARE_IMGBED_UPLOAD_FOLDER=
# 阿里云OSS配置
OSS_ENDPOINT=oss-cn-shanghai.aliyuncs.com
OSS_ENDPOINT_INNER=oss-cn-shanghai-internal.aliyuncs.com
OSS_ACCESS_KEY=LTAI5txxxxxxxxxxxxxxxx
OSS_ACCESS_KEY_SECRET=yXxxxxxxxxxxxxxxxxxxxxx
OSS_BUCKET_NAME=your-bucket-name
OSS_REGION=cn-shanghai
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_OPTIMIZER_ENABLED=false
@@ -59,6 +67,8 @@ STREAM_CHUNK_SIZE=5
######################### 日志配置 #######################################
# 日志级别 (debug, info, warning, error, critical),默认为 info
LOG_LEVEL=info
# 是否记录错误日志的请求体(可能包含敏感信息),默认 false
ERROR_LOG_RECORD_REQUEST_BODY=false
# 是否开启自动删除错误日志
AUTO_DELETE_ERROR_LOGS_ENABLED=true
# 自动删除多少天前的错误日志 (1, 7, 30)

View File

@@ -8,6 +8,9 @@ COPY ./VERSION /app
RUN pip install --no-cache-dir -r requirements.txt
COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV TZ='Asia/Shanghai'
# Expose port
EXPOSE 8000

View File

@@ -137,6 +137,8 @@ app/
### Gemini API Format (`/gemini/v1beta`)
This endpoint is directly forwarded to official Gemini API format endpoint, without advanced features.
* `GET /models`: List available Gemini models.
* `POST /models/{model_name}:generateContent`: Generate content.
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation.
@@ -145,6 +147,8 @@ app/
#### Hugging Face (HF) Compatible
If you want to use advanced features, like fake streaming, please use this endpoint.
* `GET /hf/v1/models`: List models.
* `POST /hf/v1/chat/completions`: Chat completion.
* `POST /hf/v1/embeddings`: Create text embeddings.
@@ -152,6 +156,8 @@ app/
#### Standard OpenAI
This endpoint is directly forwarded to official OpenAI Compatible API format endpoint, without advanced features.
* `GET /openai/v1/models`: List models.
* `POST /openai/v1/chat/completions`: Chat completion (Recommended).
* `POST /openai/v1/embeddings`: Create text embeddings.
@@ -178,9 +184,9 @@ app/
| `ALLOWED_TOKENS` | **Required**, list of access tokens | `[]` |
| `AUTH_TOKEN` | Super admin token, defaults to the first of `ALLOWED_TOKENS` | `sk-123456` |
| `ADMIN_SESSION_EXPIRE` | Admin session expiration time in seconds (5 minutes to 24 hours) | `3600` |
| `TEST_MODEL` | Model for testing key validity | `gemini-1.5-flash` |
| `IMAGE_MODELS` | Models supporting image generation | `["gemini-2.0-flash-exp"]` |
| `SEARCH_MODELS` | Models supporting web search | `["gemini-2.0-flash-exp"]` |
| `TEST_MODEL` | Model for testing key validity | `gemini-2.5-flash-lite` |
| `IMAGE_MODELS` | Models supporting image generation | `["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]` |
| `SEARCH_MODELS` | Models supporting web search | `["gemini-2.5-flash","gemini-2.5-pro"]` |
| `FILTERED_MODELS` | Disabled models | `[]` |
| `TOOLS_CODE_EXECUTION_ENABLED` | Enable code execution tool | `false` |
| `SHOW_SEARCH_LINK` | Display search result links in response | `true` |
@@ -199,6 +205,7 @@ app/
| `PROXIES` | List of proxy servers | `[]` |
| **Logging & Security** | | |
| `LOG_LEVEL` | Log level: `DEBUG`, `INFO`, `WARNING`, `ERROR` | `INFO` |
| `ERROR_LOG_RECORD_REQUEST_BODY` | Record request body in error logs (may contain sensitive information) | `false` |
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | Auto-delete error logs | `true` |
| `AUTO_DELETE_ERROR_LOGS_DAYS` | Error log retention period (days) | `7` |
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Auto-delete request logs | `false` |
@@ -211,9 +218,16 @@ app/
| **Image Generation** | | |
| `PAID_KEY` | Paid API Key for advanced features | `your-paid-api-key` |
| `CREATE_IMAGE_MODEL` | Image generation model | `imagen-3.0-generate-002` |
| `UPLOAD_PROVIDER` | Image upload provider: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
| `UPLOAD_PROVIDER` | Image upload provider: `smms`, `picgo`, `cloudflare_imgbed`, `aliyun_oss` | `smms` |
| `OSS_ENDPOINT` | Aliyun OSS public endpoint | `oss-cn-shanghai.aliyuncs.com` |
| `OSS_ENDPOINT_INNER` | Aliyun OSS internal endpoint (intra-VPC) | `oss-cn-shanghai-internal.aliyuncs.com` |
| `OSS_ACCESS_KEY` | Aliyun AccessKey ID | `LTAI5txxxxxxxxxxxxxxxx` |
| `OSS_ACCESS_KEY_SECRET` | Aliyun AccessKey Secret | `yXxxxxxxxxxxxxxxxxxxxxx` |
| `OSS_BUCKET_NAME` | Aliyun OSS bucket name | `your-bucket-name` |
| `OSS_REGION` | Aliyun OSS region | `cn-shanghai` |
| `SMMS_SECRET_TOKEN` | SM.MS API Token | `your-smms-token` |
| `PICGO_API_KEY` | PicoGo API Key | `your-picogo-apikey` |
| `PICGO_API_URL` | PicoGo API Server URL | `https://www.picgo.net/api/1/upload` |
| `CLOUDFLARE_IMGBED_URL` | CloudFlare ImgBed upload URL | `https://xxxxxxx.pages.dev/upload` |
| `CLOUDFLARE_IMGBED_AUTH_CODE`| CloudFlare ImgBed auth key | `your-cloudflare-imgber-auth-code` |
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| CloudFlare ImgBed upload folder | `""` |

View File

@@ -138,6 +138,8 @@ app/
### Gemini API 格式 (`/gemini/v1beta`)
此端点将请求直接转发到官方 Gemini API 格式的端点,不包含高级功能。
* `GET /models`: 列出可用的 Gemini 模型。
* `POST /models/{model_name}:generateContent`: 生成内容。
* `POST /models/{model_name}:streamGenerateContent`: 流式生成内容。
@@ -146,6 +148,8 @@ app/
#### 兼容 huggingface (HF) 格式
如果您需要使用高级功能(例如假流式输出),请使用此端点。
* `GET /hf/v1/models`: 列出模型。
* `POST /hf/v1/chat/completions`: 聊天补全。
* `POST /hf/v1/embeddings`: 创建文本嵌入。
@@ -153,6 +157,8 @@ app/
#### 标准 OpenAI 格式
此端点直接转发至官方的 OpenAI 兼容 API 格式端点,不包含高级功能。
* `GET /openai/v1/models`: 列出模型。
* `POST /openai/v1/chat/completions`: 聊天补全 (推荐,速度更快,防截断)。
* `POST /openai/v1/embeddings`: 创建文本嵌入。
@@ -178,9 +184,9 @@ app/
| `API_KEYS` | **必填**, Gemini API 密钥列表,用于负载均衡 | `[]` |
| `ALLOWED_TOKENS` | **必填**, 允许访问的 Token 列表 | `[]` |
| `AUTH_TOKEN` | 超级管理员 Token不填则使用 `ALLOWED_TOKENS` 的第一个 | `sk-123456` |
| `TEST_MODEL` | 用于测试密钥可用性的模型 | `gemini-1.5-flash` |
| `IMAGE_MODELS` | 支持绘图功能的模型列表 | `["gemini-2.0-flash-exp"]` |
| `SEARCH_MODELS` | 支持搜索功能的模型列表 | `["gemini-2.0-flash-exp"]` |
| `TEST_MODEL` | 用于测试密钥可用性的模型 | `gemini-2.5-flash-lite` |
| `IMAGE_MODELS` | 支持绘图功能的模型列表 | `["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]` |
| `SEARCH_MODELS` | 支持搜索功能的模型列表 | `["gemini-2.5-flash","gemini-2.5-pro"]` |
| `FILTERED_MODELS` | 被禁用的模型列表 | `[]` |
| `TOOLS_CODE_EXECUTION_ENABLED` | 是否启用代码执行工具 | `false` |
| `SHOW_SEARCH_LINK` | 是否在响应中显示搜索结果链接 | `true` |
@@ -199,6 +205,7 @@ app/
| `PROXIES` | 代理服务器列表 (例如 `http://user:pass@host:port`) | `[]` |
| **日志与安全** | | |
| `LOG_LEVEL` | 日志级别: `DEBUG`, `INFO`, `WARNING`, `ERROR` | `INFO` |
| `ERROR_LOG_RECORD_REQUEST_BODY` | 是否记录错误日志的请求体(可能包含敏感信息) | `false` |
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | 是否自动删除错误日志 | `true` |
| `AUTO_DELETE_ERROR_LOGS_DAYS` | 错误日志保留天数 | `7` |
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 是否自动删除请求日志 | `false` |
@@ -211,9 +218,16 @@ app/
| **图像生成相关** | | |
| `PAID_KEY` | 付费版API Key用于图片生成等高级功能 | `your-paid-api-key` |
| `CREATE_IMAGE_MODEL` | 图片生成模型 | `imagen-3.0-generate-002` |
| `UPLOAD_PROVIDER` | 图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
| `UPLOAD_PROVIDER` | 图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed`, `aliyun_oss` | `smms` |
| `OSS_ENDPOINT` | 阿里云 OSS 公网 Endpoint | `oss-cn-shanghai.aliyuncs.com` |
| `OSS_ENDPOINT_INNER` | 阿里云 OSS 内网 Endpoint同 VPC 内网访问) | `oss-cn-shanghai-internal.aliyuncs.com` |
| `OSS_ACCESS_KEY` | 阿里云 AccessKey ID | `LTAI5txxxxxxxxxxxxxxxx` |
| `OSS_ACCESS_KEY_SECRET` | 阿里云 AccessKey Secret | `yXxxxxxxxxxxxxxxxxxxxxx` |
| `OSS_BUCKET_NAME` | 阿里云 OSS Bucket 名称 | `your-bucket-name` |
| `OSS_REGION` | 阿里云 OSS 区域 Region | `cn-shanghai` |
| `SMMS_SECRET_TOKEN` | SM.MS图床的API Token | `your-smms-token` |
| `PICGO_API_KEY` | [PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
| `PICGO_API_URL` | [PicoGo](https://www.picgo.net/)图床的API服务器地址 | `https://www.picgo.net/api/1/upload` |
| `CLOUDFLARE_IMGBED_URL` | [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
| `CLOUDFLARE_IMGBED_AUTH_CODE`| CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| CloudFlare图床的上传文件夹路径 | `""` |

View File

@@ -1 +1 @@
2.2.5
2.2.8

View File

@@ -6,7 +6,7 @@ import datetime
import json
from typing import Any, Dict, List, Type, get_args, get_origin
from pydantic import ValidationError, ValidationInfo, field_validator, Field
from pydantic import Field, ValidationError, ValidationInfo, field_validator
from pydantic_settings import BaseSettings
from sqlalchemy import insert, select, update
@@ -51,8 +51,8 @@ class Settings(BaseSettings):
return v
# API相关配置
API_KEYS: List[str]=[]
ALLOWED_TOKENS: List[str]=[]
API_KEYS: List[str] = []
ALLOWED_TOKENS: List[str] = []
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
@@ -62,7 +62,9 @@ class Settings(BaseSettings):
PROXIES: List[str] = []
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
VERTEX_API_KEYS: List[str] = []
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
VERTEX_EXPRESS_BASE_URL: str = (
"https://aiplatform.googleapis.com/v1beta1/publishers/google"
)
# 智能路由配置
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
@@ -71,13 +73,19 @@ class Settings(BaseSettings):
CUSTOM_HEADERS: Dict[str, str] = {}
# 模型相关配置
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
SEARCH_MODELS: List[str] = ["gemini-2.5-flash", "gemini-2.5-pro"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
TOOLS_CODE_EXECUTION_ENABLED: bool = False
# 是否启用网址上下文
URL_CONTEXT_ENABLED: bool = False
URL_CONTEXT_MODELS: List[str] = ["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
URL_CONTEXT_MODELS: List[str] = [
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
"gemini-2.0-flash",
"gemini-2.0-flash-live-001",
]
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
THINKING_MODELS: List[str] = []
@@ -94,9 +102,17 @@ class Settings(BaseSettings):
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
PICGO_API_URL: str = "https://www.picgo.net/api/1/upload"
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
CLOUDFLARE_IMGBED_UPLOAD_FOLDER: str = ""
# 阿里云OSS配置
OSS_ENDPOINT: str = ""
OSS_ENDPOINT_INNER: str = ""
OSS_ACCESS_KEY: str = ""
OSS_ACCESS_KEY_SECRET: str = ""
OSS_BUCKET_NAME: str = ""
OSS_REGION: str = ""
# 流式输出优化器配置
STREAM_OPTIMIZER_ENABLED: bool = False
@@ -120,6 +136,7 @@ class Settings(BaseSettings):
# 日志配置
LOG_LEVEL: str = "INFO"
ERROR_LOG_RECORD_REQUEST_BODY: bool = False
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False
@@ -136,7 +153,7 @@ class Settings(BaseSettings):
default=3600,
ge=300,
le=86400,
description="Admin session expiration time in seconds (5 minutes to 24 hours)"
description="Admin session expiration time in seconds (5 minutes to 24 hours)",
)
def __init__(self, **kwargs):
@@ -168,7 +185,9 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
if isinstance(parsed, list):
return [str(item) for item in parsed]
except json.JSONDecodeError:
return [item.strip() for item in db_value.split(",") if item.strip()]
return [
item.strip() for item in db_value.split(",") if item.strip()
]
logger.warning(
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
)
@@ -220,7 +239,9 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
)
except json.JSONDecodeError:
logger.error(f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict.")
logger.error(
f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict."
)
return parsed_dict
# 处理 Dict[str, float]
elif args and args == (str, float):
@@ -242,7 +263,9 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
corrected_db_value = db_value.replace("'", '"')
parsed = json.loads(corrected_db_value)
if isinstance(parsed, dict):
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
parsed_dict = {
str(k): float(v) for k, v in parsed.items()
}
else:
logger.warning(
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
@@ -403,9 +426,7 @@ async def sync_initial_settings():
# 序列化值为字符串或 JSON 字符串
if isinstance(value, (list, dict)):
db_value = json.dumps(
value, ensure_ascii=False
)
db_value = json.dumps(value, ensure_ascii=False)
elif isinstance(value, bool):
db_value = str(value).lower()
elif value is None:

View File

@@ -9,7 +9,7 @@ MAX_RETRIES = 3 # 最大重试次数
# 模型相关常量
SUPPORTED_ROLES = ["user", "model", "system"]
DEFAULT_MODEL = "gemini-1.5-flash"
DEFAULT_MODEL = "gemini-2.5-flash-lite"
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 8192
DEFAULT_TOP_P = 0.9
@@ -27,7 +27,7 @@ DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
# 上传提供商
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed", "aliyun_oss"]
DEFAULT_UPLOAD_PROVIDER = "smms"
# 流式输出相关常量

View File

@@ -2,6 +2,7 @@
数据库服务模块
"""
import asyncio
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Union
@@ -122,16 +123,19 @@ async def add_error_log(
bool: 是否添加成功
"""
try:
# 如果request_msg是字典则转换为JSON字符串
if isinstance(request_msg, dict):
request_msg_json = request_msg
elif isinstance(request_msg, str):
try:
request_msg_json = json.loads(request_msg)
except json.JSONDecodeError:
request_msg_json = {"message": request_msg}
else:
if request_msg is None:
request_msg_json = None
else:
# 如果request_msg是字典则转换为JSON字符串
if isinstance(request_msg, dict):
request_msg_json = request_msg
elif isinstance(request_msg, str):
try:
request_msg_json = json.loads(request_msg)
except json.JSONDecodeError:
request_msg_json = {"message": request_msg}
else:
request_msg_json = None
# 插入错误日志
query = insert(ErrorLog).values(
@@ -446,24 +450,50 @@ async def delete_error_log_by_id(log_id: int) -> bool:
async def delete_all_error_logs() -> int:
"""
删除所有错误日志条目
分批删除所有错误日志,以避免大数据量下的超时和性能问题
Returns:
int: 被删除的错误日志数量。如果使用的数据库驱动不支持返回受影响行数,则返回 -1 表示操作成功
int: 被删除的错误日志数。
"""
total_deleted_count = 0
# SQLite 对 SQL 参数数量有上限(常见为 999IN 子句中过多参数会报错
# 统一使用 500兼容 SQLite/MySQL必要时可在配置中暴露该值
batch_size = 200
try:
# 直接执行删除操作,避免不必要的查询
delete_query = delete(ErrorLog)
await database.execute(delete_query)
while True:
# 1) 读取一批待删除的ID仅选择ID列以提升效率
id_query = select(ErrorLog.id).order_by(ErrorLog.id).limit(batch_size)
rows = await database.fetch_all(id_query)
if not rows:
break
logger.info("Successfully deleted all error logs.")
ids = [row["id"] for row in rows]
# 由于 databases 库的 execute 方法不返回受影响的行数,
# 返回 -1 表示删除操作成功执行,但具体删除数量未知
# 这比先查询再删除的方式更高效
return -1
# 2) 按ID批量删除
delete_query = delete(ErrorLog).where(ErrorLog.id.in_(ids))
await database.execute(delete_query)
deleted_in_batch = len(ids)
total_deleted_count += deleted_in_batch
logger.debug(f"Deleted a batch of {deleted_in_batch} error logs.")
# 若不足一个批次,说明已删除完成
if deleted_in_batch < batch_size:
break
# 3) 将控制权交还事件循环,缓解长时间占用
await asyncio.sleep(0)
logger.info(
f"Successfully deleted all error logs in batches. Total deleted: {total_deleted_count}"
)
return total_deleted_count
except Exception as e:
logger.error(f"Failed to delete all error logs: {str(e)}", exc_info=True)
logger.error(
f"Failed to delete all error logs in batches: {str(e)}", exc_info=True
)
raise

View File

@@ -131,10 +131,5 @@ def setup_exception_handlers(app: FastAPI) -> None:
logger.exception(f"Unhandled Exception: {str(exc)}")
return JSONResponse(
status_code=500,
content={
"error": {
"code": "internal_server_error",
"message": "An unexpected error occurred",
}
},
content=str(exc),
)

View File

@@ -27,7 +27,7 @@ class MessageConverter(ABC):
@abstractmethod
def convert(
self, messages: List[Dict[str, Any]]
self, messages: List[Dict[str, Any]], model: str
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
@@ -84,7 +84,7 @@ def _convert_image_to_base64(url: str) -> str:
raise Exception(f"Failed to fetch image: {response.status_code}")
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
def _process_text_with_image(text: str, model: str) -> List[Dict[str, Any]]:
"""
处理可能包含图片URL的文本提取图片并转换为base64
@@ -94,17 +94,31 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: 包含文本和图片的部分列表
"""
# 如果模型名中没有包含image当作普通文本处理
if "image" not in model:
return [{"text": text}]
parts = []
img_url_match = re.search(IMAGE_URL_PATTERN, text)
if img_url_match:
# 提取URL
img_url = img_url_match.group(2)
# 将URL对应的图片转换为base64
# 先判断是否是base64url如果是直接用不过不是将URL对应的图片转换为base64
try:
base64_data = _convert_image_to_base64(img_url)
parts.append(
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
)
base64_url_match = re.search(DATA_URL_PATTERN, img_url)
if base64_url_match:
parts.append(
{
"inline_data": {
"mimeType": base64_url_match.group(1),
"data": base64_url_match.group(2),
}
}
)
else:
base64_data = _convert_image_to_base64(img_url)
parts.append(
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
)
except Exception:
# 如果转换失败,回退到文本模式
parts.append({"text": text})
@@ -145,7 +159,7 @@ class OpenAIMessageConverter(MessageConverter):
raise
def convert(
self, messages: List[Dict[str, Any]]
self, messages: List[Dict[str, Any]], model: str
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
system_instruction_parts = []
@@ -296,7 +310,7 @@ class OpenAIMessageConverter(MessageConverter):
elif (
"content" in msg and isinstance(msg["content"], str) and msg["content"]
):
parts.extend(_process_text_with_image(msg["content"]))
parts.extend(_process_text_with_image(msg["content"], model))
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
# Keep existing tool call processing
for tool_call in msg["tool_calls"]:

View File

@@ -8,8 +8,9 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
from app.log.logger import get_openai_logger
from app.utils.helpers import is_image_upload_configured
from app.utils.uploader import ImageUploaderFactory
logger = get_openai_logger()
@@ -32,7 +33,11 @@ class GeminiResponseHandler(ResponseHandler):
self.thinking_status = False
def handle_response(
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
usage_metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
if stream:
return _handle_gemini_stream_response(response, model, stream)
@@ -40,7 +45,10 @@ class GeminiResponseHandler(ResponseHandler):
def _handle_openai_stream_response(
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
response: Dict[str, Any],
model: str,
finish_reason: str,
usage_metadata: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
choices = []
candidates = response.get("candidates", [])
@@ -54,15 +62,15 @@ def _handle_openai_stream_response(
if not text and not tool_calls and not reasoning_content:
delta = {}
else:
delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"}
delta = {
"content": text,
"reasoning_content": reasoning_content,
"role": "assistant",
}
if tool_calls:
delta["tool_calls"] = tool_calls
choice = {
"index": index,
"delta": delta,
"finish_reason": finish_reason
}
choice = {"index": index, "delta": delta, "finish_reason": finish_reason}
choices.append(choice)
template_chunk = {
@@ -73,16 +81,23 @@ def _handle_openai_stream_response(
"choices": choices,
}
if usage_metadata:
template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}
template_chunk["usage"] = {
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
"total_tokens": usage_metadata.get("totalTokenCount", 0),
}
return template_chunk
def _handle_openai_normal_response(
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
response: Dict[str, Any],
model: str,
finish_reason: str,
usage_metadata: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
choices = []
candidates = response.get("candidates", [])
for i, candidate in enumerate(candidates):
text, reasoning_content, tool_calls, _ = _extract_result(
{"candidates": [candidate]}, model, stream=False, gemini_format=False
@@ -105,7 +120,11 @@ def _handle_openai_normal_response(
"created": int(time.time()),
"model": model,
"choices": choices,
"usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)},
"usage": {
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
"total_tokens": usage_metadata.get("totalTokenCount", 0),
},
}
@@ -126,8 +145,12 @@ class OpenAIResponseHandler(ResponseHandler):
usage_metadata: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
if stream:
return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)
return _handle_openai_stream_response(
response, model, finish_reason, usage_metadata
)
return _handle_openai_normal_response(
response, model, finish_reason, usage_metadata
)
def handle_image_chat_response(
self, image_str: str, model: str, stream=False, finish_reason="stop"
@@ -181,7 +204,7 @@ def _extract_result(
gemini_format: bool = False,
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
text, reasoning_content, tool_calls, thought = "", "", [], None
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
@@ -190,7 +213,7 @@ def _extract_result(
if not parts:
logger.warning("No parts found in stream response")
return "", None, [], None
if "text" in parts[0]:
text = parts[0].get("text")
if "thought" in parts[0]:
@@ -216,13 +239,13 @@ def _extract_result(
if response.get("candidates"):
candidate = response["candidates"][0]
text, reasoning_content = "", ""
# 使用安全的访问方式
content = candidate.get("content", {})
if content and isinstance(content, dict):
parts = content.get("parts", [])
if parts:
for part in parts:
if "text" in part:
@@ -240,17 +263,28 @@ def _extract_result(
logger.error(f"Invalid content structure for model: {model}")
text = _add_search_link_text(model, candidate, text)
# 安全地获取 parts 用于工具调用提取
parts = candidate.get("content", {}).get("parts", [])
tool_calls = _extract_tool_calls(parts, gemini_format)
else:
logger.warning(f"No candidates found in response for model: {model}")
text = "暂无返回"
return text, reasoning_content, tool_calls, thought
def _has_inline_image_part(response: Dict[str, Any]) -> bool:
try:
for c in response.get("candidates", []):
for p in c.get("content", {}).get("parts", []):
if isinstance(p, dict) and ("inlineData" in p):
return True
except Exception:
return False
return False
def _extract_image_data(part: dict) -> str:
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
@@ -259,7 +293,9 @@ def _extract_image_data(part: dict) -> str:
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY,
api_url=settings.PICGO_API_URL
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
@@ -268,16 +304,30 @@ def _extract_image_data(part: dict) -> str:
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
)
elif settings.UPLOAD_PROVIDER == "aliyun_oss":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
access_key=settings.OSS_ACCESS_KEY,
access_key_secret=settings.OSS_ACCESS_KEY_SECRET,
bucket_name=settings.OSS_BUCKET_NAME,
endpoint=settings.OSS_ENDPOINT,
region=settings.OSS_REGION,
use_internal=False
)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
mime_type = part["inlineData"]["mimeType"]
# 将base64_data转成bytes数组
# Return empty string if no uploader is configured
if not is_image_upload_configured(settings):
return f"\n\n![image](data:{mime_type};base64,{base64_data})\n\n"
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data, filename)
if upload_response.success:
text = f"\n\n![image]({upload_response.data.url})\n\n"
else:
text = ""
text = f"\n\n![image](data:{mime_type};base64,{base64_data})\n\n"
return text
@@ -290,7 +340,7 @@ def _extract_tool_calls(
letters = string.ascii_lowercase + string.digits
tool_calls = list()
for i in range(len(parts)):
part = parts[i]
if not part or not isinstance(part, dict):
@@ -299,7 +349,7 @@ def _extract_tool_calls(
item = part.get("functionCall", {})
if not item or not isinstance(item, dict):
continue
if gemini_format:
tool_calls.append(part)
else:
@@ -322,6 +372,10 @@ def _extract_tool_calls(
def _handle_gemini_stream_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
# Early return raw Gemini response if no uploader configured and contains inline images
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
return response
text, reasoning_content, tool_calls, thought = _extract_result(
response, model, stream=stream, gemini_format=True
)
@@ -339,6 +393,10 @@ def _handle_gemini_stream_response(
def _handle_gemini_normal_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
# Early return raw Gemini response if no uploader configured and contains inline images
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
return response
text, reasoning_content, tool_calls, thought = _extract_result(
response, model, stream=stream, gemini_format=True
)
@@ -347,7 +405,7 @@ def _handle_gemini_normal_response(
parts = tool_calls
else:
if thought is not None:
parts.append({"text": reasoning_content,"thought": thought})
parts.append({"text": reasoning_content, "thought": thought})
part = {"text": text}
parts.append(part)
content = {"parts": parts, "role": "model"}

View File

@@ -1,9 +1,8 @@
import logging
import platform
import sys
import re
import sys
from typing import Dict, Optional
from app.utils.helpers import redact_key_for_logging as _redact_key_for_logging
# ANSI转义序列颜色代码
COLORS = {
@@ -15,7 +14,6 @@ COLORS = {
}
# Windows系统启用ANSI支持
if platform.system() == "Windows":
import ctypes
@@ -46,14 +44,16 @@ class AccessLogFormatter(logging.Formatter):
# API key patterns to match in URLs
API_KEY_PATTERNS = [
r'\bAIza[0-9A-Za-z_-]{35}', # Google API keys (like Gemini)
r'\bsk-[0-9A-Za-z_-]{20,}', # OpenAI and general sk- prefixed keys
r"\bAIza[0-9A-Za-z_-]{35}", # Google API keys (like Gemini)
r"\bsk-[0-9A-Za-z_-]{20,}", # OpenAI and general sk- prefixed keys
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Compile regex patterns for better performance
self.compiled_patterns = [re.compile(pattern) for pattern in self.API_KEY_PATTERNS]
self.compiled_patterns = [
re.compile(pattern) for pattern in self.API_KEY_PATTERNS
]
def format(self, record):
# Format the record normally first
@@ -68,9 +68,10 @@ class AccessLogFormatter(logging.Formatter):
"""
try:
for pattern in self.compiled_patterns:
def replace_key(match):
key = match.group(0)
return _redact_key_for_logging(key)
return redact_key_for_logging(key)
message = pattern.sub(replace_key, message)
@@ -78,11 +79,31 @@ class AccessLogFormatter(logging.Formatter):
except Exception as e:
# Log the error but don't expose the original message in case it contains keys
import logging
logger = logging.getLogger(__name__)
logger.error(f"Error redacting API keys in access log: {e}")
return "[LOG_REDACTION_ERROR]"
def redact_key_for_logging(key: str) -> str:
"""
Redacts API key for secure logging by showing only first and last 6 characters.
Args:
key: API key to redact
Returns:
str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases
"""
if not key:
return key
if len(key) <= 12:
return f"{key[:3]}...{key[-3:]}"
else:
return f"{key[:6]}...{key[-6:]}"
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
FORMATTER = ColoredFormatter(
"%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
@@ -326,4 +347,3 @@ def setup_access_logging():
access_logger.propagate = False
return access_logger

View File

@@ -1,19 +1,28 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
import asyncio
from copy import deepcopy
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from app.config.config import settings
from app.log.logger import get_gemini_logger
from app.core.constants import API_VERSION
from app.core.security import SecurityService
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest, GeminiEmbedRequest, GeminiBatchEmbedRequest
from app.domain.gemini_models import (
GeminiBatchEmbedRequest,
GeminiContent,
GeminiEmbedRequest,
GeminiRequest,
ResetSelectedKeysRequest,
VerifySelectedKeysRequest,
)
from app.handler.error_handler import handle_route_errors
from app.handler.retry_handler import RetryHandler
from app.log.logger import get_gemini_logger
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.embedding.gemini_embedding_service import GeminiEmbeddingService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.tts.native.tts_routes import get_tts_chat_service
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.core.constants import API_VERSION
from app.service.tts.native.tts_routes import get_tts_chat_service
from app.utils.helpers import redact_key_for_logging
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
@@ -47,8 +56,8 @@ async def get_embedding_service(key_manager: KeyManager = Depends(get_key_manage
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
_=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager)
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager),
):
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
operation_name = "list_gemini_models"
@@ -58,20 +67,30 @@ async def list_models(
try:
api_key = await key_manager.get_random_valid_key()
if not api_key:
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
raise HTTPException(
status_code=503, detail="No valid API keys available to fetch models."
)
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
models_data = await model_service.get_gemini_models(api_key)
if not models_data or "models" not in models_data:
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
raise HTTPException(
status_code=500, detail="Failed to fetch base models list."
)
models_json = deepcopy(models_data)
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
model_mapping = {
x.get("name", "").split("/", maxsplit=1)[-1]: x
for x in models_json.get("models", [])
}
def add_derived_model(base_name, suffix, display_suffix):
model = model_mapping.get(base_name)
if not model:
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
logger.warning(
f"Base model '{base_name}' not found for derived model '{suffix}'."
)
return
item = deepcopy(model)
item["name"] = f"models/{base_name}{suffix}"
@@ -85,7 +104,7 @@ async def list_models(
add_derived_model(name, "-search", " For Search")
if settings.IMAGE_MODELS:
for name in settings.IMAGE_MODELS:
add_derived_model(name, "-image", " For Image")
add_derived_model(name, "-image", " For Image")
if settings.THINKING_MODELS:
for name in settings.THINKING_MODELS:
add_derived_model(name, "-non-thinking", " Non Thinking")
@@ -97,7 +116,8 @@ async def list_models(
except Exception as e:
logger.error(f"Error getting Gemini models list: {str(e)}")
raise HTTPException(
status_code=500, detail="Internal server error while fetching Gemini models list"
status_code=500,
detail="Internal server error while fetching Gemini models list",
) from e
@@ -107,15 +127,19 @@ async def list_models(
async def generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
chat_service: GeminiChatService = Depends(get_chat_service),
):
"""处理 Gemini 非流式内容生成请求。"""
operation_name = "gemini_generate_content"
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
logger.info(f"Handling Gemini content generation request for model: {model_name}")
async with handle_route_errors(
logger, operation_name, failure_message="Content generation failed"
):
logger.info(
f"Handling Gemini content generation request for model: {model_name}"
)
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
# 检测是否为原生Gemini TTS请求
@@ -132,10 +156,13 @@ async def generate_content(
logger.info(f"TTS responseModalities: {response_modalities}")
logger.info(f"TTS speechConfig: {speech_config}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {model_name} is not supported"
)
# 所有原生TTS请求都使用TTS增强服务
if is_native_tts:
@@ -143,19 +170,17 @@ async def generate_content(
logger.info("Using native TTS enhanced service")
tts_service = await get_tts_chat_service(key_manager)
response = await tts_service.generate_content(
model=model_name,
request=request,
api_key=api_key
model=model_name, request=request, api_key=api_key
)
return response
except Exception as e:
logger.warning(f"Native TTS processing failed, falling back to standard service: {e}")
logger.warning(
f"Native TTS processing failed, falling back to standard service: {e}"
)
# 使用标准服务处理所有其他请求非TTS
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
model=model_name, request=request, api_key=api_key
)
return response
@@ -166,27 +191,53 @@ async def generate_content(
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
chat_service: GeminiChatService = Depends(get_chat_service),
):
"""处理 Gemini 流式内容生成请求。"""
operation_name = "gemini_stream_generate_content"
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
async with handle_route_errors(
logger, operation_name, failure_message="Streaming request initiation failed"
):
logger.info(
f"Handling Gemini streaming content generation for model: {model_name}"
)
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {model_name} is not supported"
)
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
raw_stream = chat_service.stream_generate_content(
model=model_name, request=request, api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
try:
# 尝试获取第一条数据,判断是正常 SSEdata: 前缀)还是错误 JSON
first_chunk = await raw_stream.__anext__()
except StopAsyncIteration:
# 如果流直接结束,退回标准 SSE 输出
return StreamingResponse(raw_stream, media_type="text/event-stream")
except Exception as e:
# 初始化流异常,直接返回 500 错误
return JSONResponse(
content={"error": {"code": e.args[0], "message": e.args[1]}},
status_code=e.args[0],
)
# 如果以 "data:" 开头,代表正常 SSE将首块和后续块一起发送
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
async def combined():
yield first_chunk
async for chunk in raw_stream:
yield chunk
return StreamingResponse(combined(), media_type="text/event-stream")
@router.post("/models/{model_name}:countTokens")
@@ -195,53 +246,60 @@ async def stream_generate_content(
async def count_tokens(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
chat_service: GeminiChatService = Depends(get_chat_service),
):
"""处理 Gemini token 计数请求。"""
operation_name = "gemini_count_tokens"
async with handle_route_errors(logger, operation_name, failure_message="Token counting failed"):
async with handle_route_errors(
logger, operation_name, failure_message="Token counting failed"
):
logger.info(f"Handling Gemini token count request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {model_name} is not supported"
)
response = await chat_service.count_tokens(
model=model_name,
request=request,
api_key=api_key
model=model_name, request=request, api_key=api_key
)
return response
@router.post("/models/{model_name}:embedContent")
@router_v1beta.post("/models/{model_name}:embedContent")
@RetryHandler(key_arg="api_key")
async def embed_content(
model_name: str,
request: GeminiEmbedRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
):
"""处理 Gemini 单一嵌入请求"""
operation_name = "gemini_embed_content"
async with handle_route_errors(logger, operation_name, failure_message="Embedding content generation failed"):
async with handle_route_errors(
logger, operation_name, failure_message="Embedding content generation failed"
):
logger.info(f"Handling Gemini embedding request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {model_name} is not supported"
)
response = await embedding_service.embed_content(
model=model_name,
request=request,
api_key=api_key
model=model_name, request=request, api_key=api_key
)
return response
@@ -252,41 +310,48 @@ async def embed_content(
async def batch_embed_contents(
model_name: str,
request: GeminiBatchEmbedRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
):
"""处理 Gemini 批量嵌入请求"""
operation_name = "gemini_batch_embed_contents"
async with handle_route_errors(logger, operation_name, failure_message="Batch embedding content generation failed"):
async with handle_route_errors(
logger,
operation_name,
failure_message="Batch embedding content generation failed",
):
logger.info(f"Handling Gemini batch embedding request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {model_name} is not supported"
)
response = await embedding_service.batch_embed_contents(
model=model_name,
request=request,
api_key=api_key
model=model_name, request=request, api_key=api_key
)
return response
@router.post("/reset-all-fail-counts")
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
async def reset_all_key_fail_counts(
key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)
):
"""批量重置Gemini API密钥的失败计数可选择性地仅重置有效或无效密钥"""
logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
logger.info(f"Received reset request with key_type: {key_type}")
try:
# 获取分类后的密钥
keys_by_status = await key_manager.get_keys_by_status()
valid_keys = keys_by_status.get("valid_keys", {})
invalid_keys = keys_by_status.get("invalid_keys", {})
# 根据类型选择要重置的密钥
keys_to_reset = []
if key_type == "valid":
@@ -298,35 +363,45 @@ async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManage
else:
# 重置所有密钥
await key_manager.reset_failure_counts()
return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})
return JSONResponse(
{"success": True, "message": "所有密钥的失败计数已重置"}
)
# 批量重置指定类型的密钥
for key in keys_to_reset:
await key_manager.reset_key_failure_count(key)
return JSONResponse({
"success": True,
"message": f"{key_type}密钥的失败计数已重置",
"reset_count": len(keys_to_reset)
})
return JSONResponse(
{
"success": True,
"message": f"{key_type}密钥的失败计数已重置",
"reset_count": len(keys_to_reset),
}
)
except Exception as e:
logger.error(f"Failed to reset key failure counts: {str(e)}")
return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
return JSONResponse(
{"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500
)
@router.post("/reset-selected-fail-counts")
async def reset_selected_key_fail_counts(
request: ResetSelectedKeysRequest,
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
"""批量重置选定Gemini API密钥的失败计数"""
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
keys_to_reset = request.keys
key_type = request.key_type
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
logger.info(
f"Received reset request for {len(keys_to_reset)} selected {key_type} keys."
)
if not keys_to_reset:
return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)
return JSONResponse(
{"success": False, "message": "没有提供需要重置的密钥"}, status_code=400
)
reset_count = 0
errors = []
@@ -338,53 +413,79 @@ async def reset_selected_key_fail_counts(
if result:
reset_count += 1
else:
logger.warning(f"Key not found during selective reset: {redact_key_for_logging(key)}")
logger.warning(
f"Key not found during selective reset: {redact_key_for_logging(key)}"
)
except Exception as key_error:
logger.error(f"Error resetting key {redact_key_for_logging(key)}: {str(key_error)}")
logger.error(
f"Error resetting key {redact_key_for_logging(key)}: {str(key_error)}"
)
errors.append(f"Key {key}: {str(key_error)}")
if errors:
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
final_success = reset_count > 0
status_code = 207 if final_success and errors else 500
return JSONResponse({
"success": final_success,
"message": error_message,
"reset_count": reset_count
}, status_code=status_code)
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
final_success = reset_count > 0
status_code = 207 if final_success and errors else 500
return JSONResponse(
{
"success": final_success,
"message": error_message,
"reset_count": reset_count,
},
status_code=status_code,
)
return JSONResponse({
"success": True,
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
"reset_count": reset_count
})
return JSONResponse(
{
"success": True,
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
"reset_count": reset_count,
}
)
except Exception as e:
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
logger.error(
f"Failed to process reset selected key failure counts request: {str(e)}"
)
return JSONResponse(
{"success": False, "message": f"批量重置处理失败: {str(e)}"},
status_code=500,
)
@router.post("/reset-fail-count/{api_key}")
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
async def reset_key_fail_count(
api_key: str, key_manager: KeyManager = Depends(get_key_manager)
):
"""重置指定Gemini API密钥的失败计数"""
logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
logger.info(f"Resetting failure count for API key: {redact_key_for_logging(api_key)}")
logger.info(
f"Resetting failure count for API key: {redact_key_for_logging(api_key)}"
)
try:
result = await key_manager.reset_key_failure_count(api_key)
if result:
return JSONResponse({"success": True, "message": "失败计数已重置"})
return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
return JSONResponse(
{"success": False, "message": "未找到指定密钥"}, status_code=404
)
except Exception as e:
logger.error(f"Failed to reset key failure count: {str(e)}")
return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
return JSONResponse(
{"success": False, "message": f"重置失败: {str(e)}"}, status_code=500
)
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
async def verify_key(
api_key: str,
chat_service: GeminiChatService = Depends(get_chat_service),
key_manager: KeyManager = Depends(get_key_manager),
):
"""验证Gemini API密钥的有效性"""
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
logger.info("Verifying API key validity")
try:
gemini_request = GeminiRequest(
contents=[
@@ -393,43 +494,47 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
parts=[{"text": "hi"}],
)
],
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10},
)
response = await chat_service.generate_content(
settings.TEST_MODEL,
gemini_request,
api_key
settings.TEST_MODEL, gemini_request, api_key
)
if response:
# 如果密钥验证成功,则重置其失败计数
await key_manager.reset_key_failure_count(api_key)
return JSONResponse({"status": "valid"})
except Exception as e:
logger.error(f"Key verification failed: {str(e)}")
async with key_manager.failure_count_lock:
if api_key in key_manager.key_failure_counts:
key_manager.key_failure_counts[api_key] += 1
logger.warning(f"Verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count")
return JSONResponse({"status": "invalid", "error": str(e)})
logger.warning(
f"Verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
)
return JSONResponse({"status": "invalid", "error": e.args[1]})
@router.post("/verify-selected-keys")
async def verify_selected_keys(
request: VerifySelectedKeysRequest,
chat_service: GeminiChatService = Depends(get_chat_service),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
"""批量验证选定Gemini API密钥的有效性"""
logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
keys_to_verify = request.keys
logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")
logger.info(
f"Received verification request for {len(keys_to_verify)} selected keys."
)
if not keys_to_verify:
return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)
return JSONResponse(
{"success": False, "message": "没有提供需要验证的密钥"}, status_code=400
)
successful_keys = []
failed_keys = {}
@@ -440,28 +545,36 @@ async def verify_selected_keys(
try:
gemini_request = GeminiRequest(
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
generation_config={
"temperature": 0.7,
"topP": 1.0,
"maxOutputTokens": 10,
},
)
await chat_service.generate_content(
settings.TEST_MODEL,
gemini_request,
api_key
settings.TEST_MODEL, gemini_request, api_key
)
successful_keys.append(api_key)
# 如果密钥验证成功,则重置其失败计数
await key_manager.reset_key_failure_count(api_key)
return api_key, "valid", None
except Exception as e:
error_message = str(e)
logger.warning(f"Key verification failed for {redact_key_for_logging(api_key)}: {error_message}")
error_message = e.args[1]
logger.warning(
f"Key verification failed for {redact_key_for_logging(api_key)}: {error_message}"
)
async with key_manager.failure_count_lock:
if api_key in key_manager.key_failure_counts:
key_manager.key_failure_counts[api_key] += 1
logger.warning(f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count")
logger.warning(
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
)
else:
key_manager.key_failure_counts[api_key] = 1
logger.warning(f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, initializing failure count to 1")
failed_keys[api_key] = error_message
key_manager.key_failure_counts[api_key] = 1
logger.warning(
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, initializing failure count to 1"
)
failed_keys[api_key] = {"error_message": e.args[1], "error_code": e.args[0]}
return api_key, "invalid", error_message
tasks = [_verify_single_key(key) for key in keys_to_verify]
@@ -469,34 +582,37 @@ async def verify_selected_keys(
for result in results:
if isinstance(result, Exception):
logger.error(f"An unexpected error occurred during bulk verification task: {result}")
elif result:
if not isinstance(result, Exception) and result:
key, status, error = result
elif isinstance(result, Exception):
logger.error(f"Task execution error during bulk verification: {result}")
logger.error(
f"An unexpected error occurred during bulk verification task: {result}"
)
valid_count = len(successful_keys)
invalid_count = len(failed_keys)
logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")
logger.info(
f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}"
)
if failed_keys:
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}"
return JSONResponse({
"success": True,
"message": message,
"successful_keys": successful_keys,
"failed_keys": failed_keys,
"valid_count": valid_count,
"invalid_count": invalid_count
})
return JSONResponse(
{
"success": True,
"message": message,
"successful_keys": successful_keys,
"failed_keys": failed_keys,
"valid_count": valid_count,
"invalid_count": invalid_count,
}
)
else:
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
return JSONResponse({
"success": True,
"message": message,
"successful_keys": successful_keys,
"failed_keys": {},
"valid_count": valid_count,
"invalid_count": 0
})
return JSONResponse(
{
"success": True,
"message": message,
"successful_keys": successful_keys,
"failed_keys": {},
"valid_count": valid_count,
"invalid_count": 0,
}
)

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from app.config.config import settings
from app.core.security import SecurityService
@@ -8,19 +8,21 @@ from app.domain.openai_models import (
EmbeddingRequest,
ImageGenerationRequest,
)
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.handler.retry_handler import RetryHandler
from app.log.logger import get_openai_compatible_logger
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
from app.service.openai_compatiable.openai_compatiable_service import (
OpenAICompatiableService,
)
from app.utils.helpers import redact_key_for_logging
router = APIRouter()
logger = get_openai_compatible_logger()
security_service = SecurityService()
async def get_key_manager():
return await get_key_manager_instance()
@@ -38,7 +40,7 @@ async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager))
@router.get("/openai/v1/models")
async def list_models(
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
@@ -47,6 +49,7 @@ async def list_models(
async with handle_route_errors(logger, operation_name):
logger.info("Handling models list request")
api_key = await key_manager.get_random_valid_key()
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
return await openai_service.get_models(api_key)
@@ -55,7 +58,7 @@ async def list_models(
@RetryHandler(key_arg="api_key")
async def chat_completion(
request: ChatRequest,
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
@@ -70,28 +73,56 @@ async def chat_completion(
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling chat completion request for model: {request.model}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
raw_response = None
if is_image_chat:
response = await openai_service.create_image_chat_completion(request, current_api_key)
return response
raw_response = await openai_service.create_image_chat_completion(
request, current_api_key
)
else:
response = await openai_service.create_chat_completion(request, current_api_key)
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
return response
raw_response = await openai_service.create_chat_completion(
request, current_api_key
)
if request.stream:
try:
# 尝试获取第一条数据,判断是正常 SSEdata: 前缀)还是错误 JSON
first_chunk = await raw_response.__anext__()
except StopAsyncIteration:
# 如果流直接结束,退回标准 SSE 输出
return StreamingResponse(raw_response, media_type="text/event-stream")
except Exception as e:
# 初始化流异常,直接返回 500 错误
return JSONResponse(
content={"error": {"code": e.args[0], "message": e.args[1]}},
status_code=e.args[0],
)
# 如果以 "data:" 开头,代表正常 SSE将首块和后续块一起发送
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
async def combined():
yield first_chunk
async for chunk in raw_response:
yield chunk
return StreamingResponse(combined(), media_type="text/event-stream")
else:
return raw_response
@router.post("/openai/v1/images/generations")
async def generate_image(
request: ImageGenerationRequest,
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
"""处理图像生成请求。"""
operation_name = "generate_image"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling image generation request for prompt: {request.prompt}")
logger.info(f"Using allowed token: {allowed_token}")
request.model = settings.CREATE_IMAGE_MODEL
return await openai_service.generate_images(request)
@@ -99,7 +130,7 @@ async def generate_image(
@router.post("/openai/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
@@ -108,6 +139,7 @@ async def embedding(
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling embedding request for model: {request.model}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
return await openai_service.create_embeddings(
input_text=request.input, model=request.model, api_key=api_key

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from app.config.config import settings
from app.core.security import SecurityService
@@ -9,15 +9,15 @@ from app.domain.openai_models import (
ImageGenerationRequest,
TTSRequest,
)
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.handler.retry_handler import RetryHandler
from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.tts.tts_service import TTSService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.service.tts.tts_service import TTSService
from app.utils.helpers import redact_key_for_logging
router = APIRouter()
@@ -53,7 +53,7 @@ async def get_tts_service():
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
):
"""获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。"""
@@ -61,6 +61,7 @@ async def list_models(
async with handle_route_errors(logger, operation_name):
logger.info("Handling models list request")
api_key = await key_manager.get_random_valid_key()
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
return await model_service.get_gemini_openai_models(api_key)
@@ -70,7 +71,7 @@ async def list_models(
@RetryHandler(key_arg="api_key")
async def chat_completion(
request: ChatRequest,
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
@@ -85,6 +86,7 @@ async def chat_completion(
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling chat completion request for model: {request.model}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
if not await model_service.check_model_support(request.model):
@@ -92,28 +94,54 @@ async def chat_completion(
status_code=400, detail=f"Model {request.model} is not supported"
)
raw_response = None
if is_image_chat:
response = await chat_service.create_image_chat_completion(request, current_api_key)
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
return response
raw_response = await chat_service.create_image_chat_completion(
request, current_api_key
)
else:
response = await chat_service.create_chat_completion(request, current_api_key)
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
return response
raw_response = await chat_service.create_chat_completion(
request, current_api_key
)
if request.stream:
try:
# 尝试获取第一条数据,判断是正常 SSEdata: 前缀)还是错误 JSON
first_chunk = await raw_response.__anext__()
except StopAsyncIteration:
# 如果流直接结束,退回标准 SSE 输出
return StreamingResponse(raw_response, media_type="text/event-stream")
except Exception as e:
# 初始化流异常,直接返回 500 错误
return JSONResponse(
content={"error": {"code": e.args[0], "message": e.args[1]}},
status_code=e.args[0],
)
# 如果以 "data:" 开头,代表正常 SSE将首块和后续块一起发送
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
async def combined():
yield first_chunk
async for chunk in raw_response:
yield chunk
return StreamingResponse(combined(), media_type="text/event-stream")
else:
return raw_response
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
request: ImageGenerationRequest,
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
):
"""处理 OpenAI 图像生成请求。"""
operation_name = "generate_image"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling image generation request for prompt: {request.prompt}")
logger.info(f"Using allowed token: {allowed_token}")
response = image_create_service.generate_images(request)
return response
@@ -122,7 +150,7 @@ async def generate_image(
@router.post("/hf/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
):
"""处理 OpenAI 文本嵌入请求。"""
@@ -130,6 +158,7 @@ async def embedding(
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling embedding request for model: {request.model}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
response = await embedding_service.create_embedding(
input_text=request.input, model=request.model, api_key=api_key
@@ -162,7 +191,7 @@ async def get_keys_list(
@router.post("/hf/v1/audio/speech")
async def text_to_speech(
request: TTSRequest,
_=Depends(security_service.verify_authorization),
allowed_token=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
tts_service: TTSService = Depends(get_tts_service),
):
@@ -171,6 +200,7 @@ async def text_to_speech(
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling TTS request for model: {request.model}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
audio_data = await tts_service.create_tts(request, api_key)
return Response(content=audio_data, media_type="audio/wav")

View File

@@ -24,10 +24,13 @@ from app.router import (
)
from app.service.key.key_manager import get_key_manager_instance
from app.service.stats.stats_service import StatsService
from app.utils.static_version import get_static_url
logger = get_routes_logger()
templates = Jinja2Templates(directory="app/templates")
# 设置模板全局变量
templates.env.globals["static_url"] = get_static_url
def setup_routers(app: FastAPI) -> None:

View File

@@ -1,16 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from copy import deepcopy
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from app.config.config import settings
from app.log.logger import get_vertex_express_logger
from app.core.constants import API_VERSION
from app.core.security import SecurityService
from app.domain.gemini_models import GeminiRequest
from app.handler.error_handler import handle_route_errors
from app.handler.retry_handler import RetryHandler
from app.log.logger import get_vertex_express_logger
from app.service.chat.vertex_express_chat_service import GeminiChatService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.core.constants import API_VERSION
from app.utils.helpers import redact_key_for_logging
router = APIRouter(prefix=f"/vertex-express/{API_VERSION}")
@@ -37,8 +39,8 @@ async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
@router.get("/models")
async def list_models(
_=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager)
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager),
):
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
operation_name = "list_gemini_models"
@@ -48,20 +50,30 @@ async def list_models(
try:
api_key = await key_manager.get_random_valid_key()
if not api_key:
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
raise HTTPException(
status_code=503, detail="No valid API keys available to fetch models."
)
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
models_data = await model_service.get_gemini_models(api_key)
if not models_data or "models" not in models_data:
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
raise HTTPException(
status_code=500, detail="Failed to fetch base models list."
)
models_json = deepcopy(models_data)
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
model_mapping = {
x.get("name", "").split("/", maxsplit=1)[-1]: x
for x in models_json.get("models", [])
}
def add_derived_model(base_name, suffix, display_suffix):
model = model_mapping.get(base_name)
if not model:
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
logger.warning(
f"Base model '{base_name}' not found for derived model '{suffix}'."
)
return
item = deepcopy(model)
item["name"] = f"models/{base_name}{suffix}"
@@ -75,7 +87,7 @@ async def list_models(
add_derived_model(name, "-search", " For Search")
if settings.IMAGE_MODELS:
for name in settings.IMAGE_MODELS:
add_derived_model(name, "-image", " For Image")
add_derived_model(name, "-image", " For Image")
if settings.THINKING_MODELS:
for name in settings.THINKING_MODELS:
add_derived_model(name, "-non-thinking", " Non Thinking")
@@ -87,7 +99,8 @@ async def list_models(
except Exception as e:
logger.error(f"Error getting Gemini models list: {str(e)}")
raise HTTPException(
status_code=500, detail="Internal server error while fetching Gemini models list"
status_code=500,
detail="Internal server error while fetching Gemini models list",
) from e
@@ -96,25 +109,30 @@ async def list_models(
async def generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
chat_service: GeminiChatService = Depends(get_chat_service),
):
"""处理 Gemini 非流式内容生成请求。"""
operation_name = "gemini_generate_content"
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
logger.info(f"Handling Gemini content generation request for model: {model_name}")
async with handle_route_errors(
logger, operation_name, failure_message="Content generation failed"
):
logger.info(
f"Handling Gemini content generation request for model: {model_name}"
)
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {model_name} is not supported"
)
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
model=model_name, request=request, api_key=api_key
)
return response
@@ -124,24 +142,50 @@ async def generate_content(
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
chat_service: GeminiChatService = Depends(get_chat_service),
):
"""处理 Gemini 流式内容生成请求。"""
operation_name = "gemini_stream_generate_content"
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
async with handle_route_errors(
logger, operation_name, failure_message="Streaming request initiation failed"
):
logger.info(
f"Handling Gemini streaming content generation for model: {model_name}"
)
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using allowed token: {allowed_token}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {model_name} is not supported"
)
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
raw_stream = chat_service.stream_generate_content(
model=model_name, request=request, api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
try:
# 尝试获取第一条数据,判断是正常 SSEdata: 前缀)还是错误 JSON
first_chunk = await raw_stream.__anext__()
except StopAsyncIteration:
# 如果流直接结束,退回标准 SSE 输出
return StreamingResponse(raw_stream, media_type="text/event-stream")
except Exception as e:
# 初始化流异常,直接返回 500 错误
return JSONResponse(
content={"error": {"code": e.args[0], "message": e.args[1]}},
status_code=e.args[0],
)
# 如果以 "data:" 开头,代表正常 SSE将首块和后续块一起发送
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
async def combined():
yield first_chunk
async for chunk in raw_stream:
yield chunk
return StreamingResponse(combined(), media_type="text/event-stream")

View File

@@ -1,4 +1,3 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from app.config.config import settings
@@ -6,9 +5,9 @@ from app.domain.gemini_models import GeminiContent, GeminiRequest
from app.log.logger import Logger
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.error_log.error_log_service import delete_old_error_logs
from app.service.files.files_service import get_files_service
from app.service.key.key_manager import get_key_manager_instance
from app.service.request_log.request_log_service import delete_old_request_logs_task
from app.service.files.files_service import get_files_service
from app.utils.helpers import redact_key_for_logging
logger = Logger.setup_logger("scheduler")
@@ -106,15 +105,16 @@ async def cleanup_expired_files():
try:
files_service = await get_files_service()
deleted_count = await files_service.cleanup_expired_files()
if deleted_count > 0:
logger.info(f"Successfully cleaned up {deleted_count} expired files.")
else:
logger.info("No expired files to clean up.")
except Exception as e:
logger.error(
f"An error occurred during the scheduled file cleanup: {str(e)}", exc_info=True
f"An error occurred during the scheduled file cleanup: {str(e)}",
exc_info=True,
)
@@ -122,44 +122,45 @@ def setup_scheduler():
"""设置并启动 APScheduler"""
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
# 添加检查失败密钥的定时任务
scheduler.add_job(
check_failed_keys,
"interval",
hours=settings.CHECK_INTERVAL_HOURS,
id="check_failed_keys_job",
name="Check Failed API Keys",
)
logger.info(
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
)
if settings.CHECK_INTERVAL_HOURS != 0:
scheduler.add_job(
check_failed_keys,
"interval",
hours=settings.CHECK_INTERVAL_HOURS,
id="check_failed_keys_job",
name="Check Failed API Keys",
)
logger.info(
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
)
# 新增:添加自动删除错误日志的定时任务,每天凌晨3点执行
# 新增:添加自动删除错误日志的定时任务,每天凌晨0点执行
scheduler.add_job(
delete_old_error_logs,
"cron",
hour=3,
hour=0,
minute=0,
id="delete_old_error_logs_job",
name="Delete Old Error Logs",
)
logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")
# 新增:添加自动删除请求日志的定时任务,每天凌晨3点05分执行
# 新增:添加自动删除请求日志的定时任务,每天凌晨0点执行
scheduler.add_job(
delete_old_request_logs_task,
"cron",
hour=3,
minute=5,
hour=0,
minute=0,
id="delete_old_request_logs_job",
name="Delete Old Request Logs",
)
logger.info(
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
)
# 新增:添加文件过期清理的定时任务,每小时执行一次
if getattr(settings, 'FILES_CLEANUP_ENABLED', True):
cleanup_interval = getattr(settings, 'FILES_CLEANUP_INTERVAL_HOURS', 1)
if getattr(settings, "FILES_CLEANUP_ENABLED", True):
cleanup_interval = getattr(settings, "FILES_CLEANUP_INTERVAL_HOURS", 1)
scheduler.add_job(
cleanup_expired_files,
"interval",

View File

@@ -365,13 +365,9 @@ class GeminiChatService:
return self.response_handler.handle_response(response, model, stream=False)
except Exception as e:
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(f"Normal API call failed with error: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
@@ -379,7 +375,7 @@ class GeminiChatService:
error_type="gemini-chat-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
request_datetime=request_datetime,
)
raise e
@@ -416,13 +412,9 @@ class GeminiChatService:
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(f"Count tokens API call failed with error: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
@@ -430,7 +422,7 @@ class GeminiChatService:
error_type="gemini-count-tokens",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
)
raise e
finally:
@@ -508,15 +500,11 @@ class GeminiChatService:
except Exception as e:
retries += 1
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.warning(
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
)
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=current_attempt_key,
@@ -524,7 +512,9 @@ class GeminiChatService:
error_type="gemini-chat-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=(
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
),
request_datetime=request_datetime,
)
@@ -537,11 +527,11 @@ class GeminiChatService:
)
else:
logger.error(f"No valid API key available after {retries} retries.")
break
raise
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming.")
break
raise
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)

View File

@@ -3,7 +3,6 @@
import asyncio
import datetime
import json
import re
import time
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
@@ -285,7 +284,9 @@ class OpenAIChatService:
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
messages, instruction = self.message_converter.convert(request.messages)
messages, instruction = self.message_converter.convert(
request.messages, request.model
)
payload = _build_payload(request, messages, instruction)
@@ -337,7 +338,8 @@ class OpenAIChatService:
except Exception as e:
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(f"API call failed for model {model}: {error_log_msg}")
# 特别记录 max_tokens 相关的错误
@@ -351,16 +353,13 @@ class OpenAIChatService:
if "parts" in error_log_msg:
logger.error("This is likely a response processing error")
match = re.search(r"status code (\d+)", error_log_msg)
status_code = int(match.group(1)) if match else 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-chat-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
request_datetime=request_datetime,
)
raise e
@@ -538,27 +537,21 @@ class OpenAIChatService:
except Exception as e:
retries += 1
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.warning(
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
)
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
if isinstance(e, asyncio.TimeoutError):
status_code = 408
else:
status_code = 500
await add_error_log(
gemini_key=current_attempt_key,
model_name=model,
error_type="openai-chat-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=(
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
),
request_datetime=request_datetime,
)
@@ -575,7 +568,7 @@ class OpenAIChatService:
logger.error(
f"No valid API key available after {retries} retries, ceasing attempts for this request."
)
break
raise
else:
logger.error(
"KeyManager not available, cannot switch API key. Ceasing attempts for this request."
@@ -586,6 +579,7 @@ class OpenAIChatService:
logger.error(
f"Max retries ({max_retries}) reached for streaming model {model}."
)
raise
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
@@ -598,13 +592,6 @@ class OpenAIChatService:
request_time=request_datetime,
)
if not is_success:
logger.error(
f"Streaming failed permanently for model {model} after {retries} attempts."
)
yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
yield "data: [DONE]\n\n"
async def create_image_chat_completion(
self, request: ChatRequest, api_key: str
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
@@ -663,20 +650,23 @@ class OpenAIChatService:
yield "data: [DONE]\n\n"
except Exception as e:
is_success = False
error_log_msg = f"Stream image completion failed for model {model}: {e}"
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(error_log_msg)
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-image-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg={"image_data_truncated": image_data[:1000]},
request_msg=(
{"image_data_truncated": image_data[:1000]}
if settings.ERROR_LOG_RECORD_REQUEST_BODY
else None
),
request_datetime=request_datetime,
)
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
yield "data: [DONE]\n\n"
raise
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
@@ -714,19 +704,23 @@ class OpenAIChatService:
return result
except Exception as e:
is_success = False
error_log_msg = f"Normal image completion failed for model {model}: {e}"
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(error_log_msg)
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-image-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg={"image_data_truncated": image_data[:1000]},
request_msg=(
{"image_data_truncated": image_data[:1000]}
if settings.ERROR_LOG_RECORD_REQUEST_BODY
else None
),
request_datetime=request_datetime,
)
raise e
raise
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)

View File

@@ -2,7 +2,6 @@
import datetime
import json
import re
import time
from typing import Any, AsyncGenerator, Dict, List
@@ -278,13 +277,9 @@ class GeminiChatService:
return self.response_handler.handle_response(response, model, stream=False)
except Exception as e:
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(f"Normal API call failed with error: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
@@ -292,7 +287,7 @@ class GeminiChatService:
error_type="gemini-chat-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
request_datetime=request_datetime,
)
raise e
@@ -356,15 +351,11 @@ class GeminiChatService:
except Exception as e:
retries += 1
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.warning(
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
)
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=current_attempt_key,
@@ -372,7 +363,9 @@ class GeminiChatService:
error_type="gemini-chat-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=(
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
),
request_datetime=request_datetime,
)
@@ -385,11 +378,11 @@ class GeminiChatService:
)
else:
logger.error(f"No valid API key available after {retries} retries.")
break
raise
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming.")
break
raise
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)

View File

@@ -1,24 +1,31 @@
# app/services/chat/api_client.py
from typing import Dict, Any, AsyncGenerator, Optional
import httpx
import random
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, Optional
import httpx
from app.config.config import settings
from app.log.logger import get_api_client_logger
from app.core.constants import DEFAULT_TIMEOUT
from app.log.logger import get_api_client_logger
logger = get_api_client_logger()
class ApiClient(ABC):
"""API客户端基类"""
@abstractmethod
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
async def generate_content(
self, payload: Dict[str, Any], model: str, api_key: str
) -> Dict[str, Any]:
pass
@abstractmethod
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
async def stream_generate_content(
self, payload: Dict[str, Any], model: str, api_key: str
) -> AsyncGenerator[str, None]:
pass
@@ -50,7 +57,7 @@ class GeminiApiClient(ApiClient):
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
"""获取可用的 Gemini 模型列表"""
timeout = httpx.Timeout(timeout=5)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -73,11 +80,13 @@ class GeminiApiClient(ApiClient):
except httpx.RequestError as e:
logger.error(f"请求模型列表失败: {e}")
return None
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
async def generate_content(
self, payload: Dict[str, Any], model: str, api_key: str
) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -85,42 +94,33 @@ class GeminiApiClient(ApiClient):
else:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
try:
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(f"API call failed - Status: {response.status_code}, Content: {error_content}")
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
response_data = response.json()
# 检查响应结构的基本信息
if not response_data.get("candidates"):
logger.warning("No candidates found in API response")
return response_data
except httpx.TimeoutException as e:
logger.error(f"Request timeout: {e}")
raise Exception(f"Request timeout: {e}")
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
raise Exception(f"Request error: {e}")
except Exception as e:
logger.error(f"Unexpected error: {e}")
raise
response = await client.post(url, json=payload, headers=headers)
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
if response.status_code != 200:
error_content = response.text
logger.error(
f"API call failed - Status: {response.status_code}, Content: {error_content}"
)
raise Exception(response.status_code, error_content)
response_data = response.json()
# 检查响应结构的基本信息
if not response_data.get("candidates"):
logger.warning("No candidates found in API response")
return response_data
async def stream_generate_content(
self, payload: Dict[str, Any], model: str, api_key: str
) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -132,15 +132,19 @@ class GeminiApiClient(ApiClient):
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
async with client.stream(
method="POST", url=url, json=payload, headers=headers
) as response:
if response.status_code != 200:
error_content = await response.aread()
error_msg = error_content.decode("utf-8")
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
raise Exception(response.status_code, error_msg)
async for line in response.aiter_lines():
yield line
async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
async def count_tokens(
self, payload: Dict[str, Any], model: str, api_key: str
) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
@@ -158,14 +162,16 @@ class GeminiApiClient(ApiClient):
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
raise Exception(response.status_code, error_content)
return response.json()
async def embed_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
async def embed_content(
self, payload: Dict[str, Any], model: str, api_key: str
) -> Dict[str, Any]:
"""单一嵌入内容生成"""
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -177,32 +183,22 @@ class GeminiApiClient(ApiClient):
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:embedContent?key={api_key}"
try:
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}")
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
except httpx.TimeoutException as e:
logger.error(f"Embedding request timeout: {e}")
raise Exception(f"Request timeout: {e}")
except httpx.RequestError as e:
logger.error(f"Embedding request error: {e}")
raise Exception(f"Request error: {e}")
except Exception as e:
logger.error(f"Unexpected embedding error: {e}")
raise
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(
f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}"
)
raise Exception(response.status_code, error_content)
return response.json()
async def batch_embed_contents(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
async def batch_embed_contents(
self, payload: Dict[str, Any], model: str, api_key: str
) -> Dict[str, Any]:
"""批量嵌入内容生成"""
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -214,26 +210,14 @@ class GeminiApiClient(ApiClient):
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:batchEmbedContents?key={api_key}"
try:
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}")
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
except httpx.TimeoutException as e:
logger.error(f"Batch embedding request timeout: {e}")
raise Exception(f"Request timeout: {e}")
except httpx.RequestError as e:
logger.error(f"Batch embedding request error: {e}")
raise Exception(f"Request error: {e}")
except Exception as e:
logger.error(f"Unexpected batch embedding error: {e}")
raise
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(
f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}"
)
raise Exception(response.status_code, error_content)
return response.json()
class OpenaiApiClient(ApiClient):
@@ -242,7 +226,7 @@ class OpenaiApiClient(ApiClient):
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
self.base_url = base_url
self.timeout = timeout
def _prepare_headers(self, api_key: str) -> Dict[str, str]:
headers = {"Authorization": f"Bearer {api_key}"}
if settings.CUSTOM_HEADERS:
@@ -267,12 +251,16 @@ class OpenaiApiClient(ApiClient):
response = await client.get(url, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
raise Exception(response.status_code, error_content)
return response.json()
async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
async def generate_content(
self, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
logger.info(f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}")
logger.info(
f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}"
)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -287,10 +275,12 @@ class OpenaiApiClient(ApiClient):
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
raise Exception(response.status_code, error_content)
return response.json()
async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
async def stream_generate_content(
self, payload: Dict[str, Any], api_key: str
) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
proxy_to_use = None
if settings.PROXIES:
@@ -303,17 +293,21 @@ class OpenaiApiClient(ApiClient):
headers = self._prepare_headers(api_key)
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/chat/completions"
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
async with client.stream(
method="POST", url=url, json=payload, headers=headers
) as response:
if response.status_code != 200:
error_content = await response.aread()
error_msg = error_content.decode("utf-8")
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
raise Exception(response.status_code, error_msg)
async for line in response.aiter_lines():
yield line
async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
async def create_embeddings(
self, input: str, model: str, api_key: str
) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -332,10 +326,12 @@ class OpenaiApiClient(ApiClient):
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
raise Exception(response.status_code, error_content)
return response.json()
async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
async def generate_images(
self, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
proxy_to_use = None
@@ -352,5 +348,5 @@ class OpenaiApiClient(ApiClient):
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
raise Exception(response.status_code, error_content)
return response.json()

View File

@@ -1,5 +1,4 @@
import datetime
import re
import time
from typing import List, Union
@@ -56,13 +55,9 @@ class EmbeddingService:
raise e
except Exception as e:
is_success = False
status_code = 500
error_log_msg = f"Generic error: {e}"
logger.error(f"Error creating embedding (Exception): {error_log_msg}")
match = re.search(r"status code (\d+)", str(e))
if match:
status_code = int(match.group(1))
else:
status_code = 500
raise e
finally:
end_time = time.perf_counter()
@@ -74,7 +69,11 @@ class EmbeddingService:
error_type="openai-embedding",
error_log=error_log_msg,
error_code=status_code,
request_msg=request_msg_log,
request_msg=(
request_msg_log
if settings.ERROR_LOG_RECORD_REQUEST_BODY
else None
),
request_datetime=request_datetime,
)
await add_request_log(

View File

@@ -1,7 +1,6 @@
# app/service/embedding/gemini_embedding_service.py
import datetime
import re
import time
from typing import Any, Dict
@@ -69,13 +68,9 @@ class GeminiEmbeddingService:
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(f"Single embedding API call failed: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
@@ -83,7 +78,7 @@ class GeminiEmbeddingService:
error_type="gemini-embed-single",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
request_datetime=request_datetime,
)
raise e
@@ -119,13 +114,9 @@ class GeminiEmbeddingService:
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(f"Batch embedding API call failed: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
@@ -133,7 +124,7 @@ class GeminiEmbeddingService:
error_type="gemini-embed-batch",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
request_datetime=request_datetime,
)
raise e

View File

@@ -1,4 +1,4 @@
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from sqlalchemy import delete, func, select
@@ -28,7 +28,7 @@ async def delete_old_error_logs():
)
return
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
logger.info(
f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."

View File

@@ -9,6 +9,7 @@ from app.config.config import settings
from app.core.constants import VALID_IMAGE_RATIOS
from app.domain.openai_models import ImageGenerationRequest
from app.log.logger import get_image_create_logger
from app.utils.helpers import is_image_upload_configured
from app.utils.uploader import ImageUploaderFactory
logger = get_image_create_logger()
@@ -97,12 +98,18 @@ class ImageCreateService:
image_data = generated_image.image.image_bytes
image_uploader = None
if request.response_format == "b64_json":
# Return base64 if explicitly requested or if no uploader is configured
if (
request.response_format == "b64_json"
or not is_image_upload_configured(settings)
):
base64_image = base64.b64encode(image_data).decode("utf-8")
images_data.append(
{"b64_json": base64_image, "revised_prompt": request.prompt}
)
continue
else:
# Upload to configured provider
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
@@ -115,6 +122,7 @@ class ImageCreateService:
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY,
api_url=settings.PICGO_API_URL,
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
@@ -123,6 +131,16 @@ class ImageCreateService:
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
)
elif settings.UPLOAD_PROVIDER == "aliyun_oss":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
access_key=settings.OSS_ACCESS_KEY,
access_key_secret=settings.OSS_ACCESS_KEY_SECRET,
bucket_name=settings.OSS_BUCKET_NAME,
endpoint=settings.OSS_ENDPOINT,
region=settings.OSS_REGION,
use_internal=False
)
else:
raise ValueError(
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"

View File

@@ -1,6 +1,4 @@
import datetime
import json
import re
import time
from typing import Any, AsyncGenerator, Dict, Union
@@ -80,13 +78,9 @@ class OpenAICompatiableService:
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.error(f"Normal API call failed with error: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
@@ -94,7 +88,7 @@ class OpenAICompatiableService:
error_type="openai-compatiable-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=request,
request_msg=request if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
)
raise e
finally:
@@ -138,15 +132,11 @@ class OpenAICompatiableService:
except Exception as e:
retries += 1
is_success = False
error_log_msg = str(e)
status_code = e.args[0]
error_log_msg = e.args[1]
logger.warning(
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
)
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=current_attempt_key,
@@ -154,7 +144,9 @@ class OpenAICompatiableService:
error_type="openai-compatiable-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
request_msg=(
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
),
request_datetime=request_datetime,
)
@@ -170,14 +162,14 @@ class OpenAICompatiableService:
logger.error(
f"No valid API key available after {retries} retries."
)
break
raise
else:
logger.error("KeyManager not available for retry logic.")
break
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming.")
break
raise
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
@@ -189,6 +181,3 @@ class OpenAICompatiableService:
latency_ms=latency_ms,
request_time=request_datetime,
)
if not is_success and retries >= max_retries:
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"

View File

@@ -2,12 +2,12 @@
Service for request log operations.
"""
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from sqlalchemy import delete
from app.database.connection import database
from app.config.config import settings
from app.database.connection import database
from app.database.models import RequestLog
from app.log.logger import get_request_log_logger
@@ -30,7 +30,7 @@ async def delete_old_request_logs_task():
)
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
query = delete(RequestLog).where(RequestLog.request_time < cutoff_date)

View File

@@ -3,14 +3,16 @@
继承自原始聊天服务添加原生Gemini TTS支持单人和多人保持向后兼容
"""
import time
import datetime
import time
from typing import Any, Dict
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.tts.native.tts_response_handler import TTSResponseHandler
from app.config.config import settings
from app.database.services import add_error_log, add_request_log
from app.domain.gemini_models import GeminiRequest
from app.log.logger import get_gemini_logger
from app.database.services import add_request_log, add_error_log
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.tts.native.tts_response_handler import TTSResponseHandler
logger = get_gemini_logger()
@@ -28,7 +30,9 @@ class TTSGeminiChatService(GeminiChatService):
super().__init__(base_url, key_manager)
# 使用TTS响应处理器替换原始处理器
self.response_handler = TTSResponseHandler()
logger.info("TTS Gemini Chat Service initialized with multi-speaker TTS support")
logger.info(
"TTS Gemini Chat Service initialized with multi-speaker TTS support"
)
async def generate_content(
self, model: str, request: GeminiRequest, api_key: str
@@ -55,7 +59,9 @@ class TTSGeminiChatService(GeminiChatService):
logger.error(f"TTS API call failed with error: {e}")
raise
async def _handle_tts_request(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
async def _handle_tts_request(
self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
"""
处理TTS特定的请求包含完整的日志记录功能
"""
@@ -89,14 +95,24 @@ class TTSGeminiChatService(GeminiChatService):
if request.generationConfig:
# 添加TTS特定字段
if request.generationConfig.responseModalities:
payload["generationConfig"]["responseModalities"] = request.generationConfig.responseModalities
logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}")
payload["generationConfig"][
"responseModalities"
] = request.generationConfig.responseModalities
logger.info(
f"Added responseModalities: {request.generationConfig.responseModalities}"
)
if request.generationConfig.speechConfig:
payload["generationConfig"]["speechConfig"] = request.generationConfig.speechConfig
logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}")
payload["generationConfig"][
"speechConfig"
] = request.generationConfig.speechConfig
logger.info(
f"Added speechConfig: {request.generationConfig.speechConfig}"
)
else:
logger.warning("No generationConfig found in request, TTS fields may be missing")
logger.warning(
"No generationConfig found in request, TTS fields may be missing"
)
logger.info(f"TTS payload before API call: {payload}")
@@ -117,6 +133,7 @@ class TTSGeminiChatService(GeminiChatService):
# 尝试从错误消息中提取状态码
import re
match = re.search(r"status code (\d+)", error_msg)
if match:
status_code = int(match.group(1))
@@ -130,7 +147,11 @@ class TTSGeminiChatService(GeminiChatService):
error_type="tts-api-error",
error_log=error_msg,
error_code=status_code,
request_msg=request.model_dump(exclude_none=False)
request_msg=(
request.model_dump(exclude_none=False)
if settings.ERROR_LOG_RECORD_REQUEST_BODY
else None
),
)
logger.error(f"TTS API call failed: {error_msg}")
@@ -147,5 +168,5 @@ class TTSGeminiChatService(GeminiChatService):
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
request_time=request_datetime,
)

View File

@@ -40,7 +40,7 @@ class TTSService:
error_log_msg = ""
try:
client = genai.Client(api_key=api_key)
response =await client.aio.models.generate_content(
response = await client.aio.models.generate_content(
model=settings.TTS_MODEL,
contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
config={
@@ -48,7 +48,11 @@ class TTSService:
"speech_config": {
"voice_config": {
"prebuilt_voice_config": {
"voice_name": request.voice if request.voice in TTS_VOICE_NAMES else settings.TTS_VOICE_NAME
"voice_name": (
request.voice
if request.voice in TTS_VOICE_NAMES
else settings.TTS_VOICE_NAME
)
}
}
},
@@ -59,7 +63,9 @@ class TTSService:
and response.candidates[0].content.parts
and response.candidates[0].content.parts[0].inline_data
):
raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
raw_audio_data = (
response.candidates[0].content.parts[0].inline_data.data
)
is_success = True
status_code = 200
return _create_wav_file(raw_audio_data)
@@ -83,13 +89,17 @@ class TTSService:
error_type="google-tts",
error_log=error_log_msg,
error_code=status_code,
request_msg=request.input
)
request_msg=(
request.input
if settings.ERROR_LOG_RECORD_REQUEST_BODY
else None
),
)
await add_request_log(
model_name=settings.TTS_MODEL,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
)
request_time=request_datetime,
)

View File

@@ -104,6 +104,24 @@ document.addEventListener("DOMContentLoaded", function () {
});
}
// 检查间隔小时数输入控制
const checkIntervalInput = document.getElementById("CHECK_INTERVAL_HOURS");
if (checkIntervalInput) {
checkIntervalInput.addEventListener("input", function () {
let value = parseFloat(this.value);
if (isNaN(value) || value < 0) {
this.value = 0;
}
});
checkIntervalInput.addEventListener("change", function () {
let value = parseFloat(this.value);
if (isNaN(value) || value < 0) {
this.value = 0;
}
});
}
// Toggle switch events
const toggleSwitches = document.querySelectorAll(".toggle-switch");
toggleSwitches.forEach((toggleSwitch) => {
@@ -771,6 +789,10 @@ async function initConfig() {
if (typeof config.AUTO_DELETE_ERROR_LOGS_DAYS === "undefined") {
config.AUTO_DELETE_ERROR_LOGS_DAYS = 7;
}
// 错误日志是否记录请求体(默认不记录)
if (typeof config.ERROR_LOG_RECORD_REQUEST_BODY === "undefined") {
config.ERROR_LOG_RECORD_REQUEST_BODY = false;
}
// --- 结束:处理自动删除错误日志配置的默认值 ---
// --- 新增:处理自动删除请求日志配置的默认值 ---

View File

@@ -541,30 +541,13 @@ function showVerificationResultModal(data) {
const errorGroups = {};
Object.entries(failedKeys).forEach(([key, error]) => {
// 提取错误码或使用完整错误信息作为分组键
let errorCode = error;
// 尝试提取常见的错误码模式
const errorCodePatterns = [
/status code (\d+)/,
];
for (const pattern of errorCodePatterns) {
const match = error.match(pattern);
if (match) {
errorCode = match[1] || match[0];
break;
}
}
// 如果没有匹配到特定模式使用500
if (errorCode === error) {
errorCode = 500;
}
let errorCode = error["error_code"];
let errorMessage = error["error_message"];
if (!errorGroups[errorCode]) {
errorGroups[errorCode] = [];
}
errorGroups[errorCode].push({ key, error });
errorGroups[errorCode].push({ key, errorMessage });
});
// 创建分组展示容器
@@ -609,7 +592,7 @@ function showVerificationResultModal(data) {
const keysList = document.createElement("div");
keysList.className = "group-keys-list space-y-1";
keyErrorPairs.forEach(({ key, error }) => {
keyErrorPairs.forEach(({ key, errorMessage }) => {
const keyItem = document.createElement("div");
keyItem.className = "flex flex-col items-start bg-gray-50 p-2 rounded border";
@@ -624,7 +607,7 @@ function showVerificationResultModal(data) {
const detailsButton = document.createElement("button");
detailsButton.className = "ml-2 px-2 py-0.5 bg-red-200 hover:bg-red-300 text-red-700 text-xs rounded transition-colors";
detailsButton.innerHTML = '<i class="fas fa-info-circle mr-1"></i>详情';
detailsButton.dataset.error = error;
detailsButton.dataset.error = errorMessage;
detailsButton.onclick = (e) => {
e.stopPropagation();
const button = e.currentTarget;
@@ -984,7 +967,6 @@ function initializeGlobalBatchVerificationHandlers() {
document.getElementById("verifyModal").classList.add("hidden");
};
// executeVerifyAll 变为 initializeGlobalBatchVerificationHandlers 的局部函数
async function executeVerifyAll(type) {
closeVerifyModal();
const keysToVerify = getSelectedKeys(type);
@@ -1055,8 +1037,6 @@ function initializeGlobalBatchVerificationHandlers() {
invalid_count: Object.keys(allFailedKeys).length
});
}
// The confirmButton.onclick in showVerifyModal (defined earlier in initializeGlobalBatchVerificationHandlers)
// will correctly reference this local executeVerifyAll due to closure.
}
// --- 进度条模态框函数 ---
@@ -2550,6 +2530,7 @@ function showVerifyModalForAllKeys(allKeys) {
modalElement.classList.remove("hidden");
}
// 执行验证所有密钥
async function executeVerifyAllKeys(allKeys) {
closeVerifyModal();

View File

@@ -51,7 +51,7 @@
</div>
<h2 class="text-3xl font-extrabold text-center text-gray-800 mb-8 animate-slide-down">
<img src="/static/icons/logo.png" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
<img src="{{ static_url('icons/logo.png') }}" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
Gemini Balance
</h2>

View File

@@ -4,21 +4,21 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>{% block title %}Gemini Balance{% endblock %}</title>
<link rel="manifest" href="/static/manifest.json" />
<link rel="manifest" href="{{ static_url('manifest.json') }}" />
<meta name="theme-color" content="#4F46E5" />
<meta name="apple-mobile-web-app-capable" content="yes" />
<meta name="apple-mobile-web-app-status-bar-style" content="black" />
<meta name="apple-mobile-web-app-title" content="GBalance" />
<link rel="icon" href="/static/icons/icon-192x192.png" />
<link rel="icon" href="{{ static_url('icons/icon-192x192.png') }}" />
<link
href="/static/css/fonts.css"
href="{{ static_url('css/fonts.css') }}"
rel="stylesheet"
/>
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css"
/>
<script src="/static/js/tailwindcss.js"></script>
<script src="{{ static_url('js/tailwindcss.js') }}"></script>
<script>
tailwind.config = {
theme: {

File diff suppressed because it is too large Load Diff

View File

@@ -45,6 +45,179 @@ endblock %} {% block head_extra_styles %}
.search-container {
grid-template-columns: 1fr;
}
/* 移动端主容器布局 */
.mobile-buttons-container {
display: flex !important;
flex-direction: column !important;
gap: 1rem !important;
align-items: stretch !important;
width: 100% !important;
padding: 0 !important;
margin: 0 !important;
}
/* 移动端搜索控件布局优化 */
.mobile-search-controls {
grid-template-columns: 1fr !important;
gap: 0.75rem !important;
width: 100% !important;
margin-bottom: 0.5rem !important;
}
/* 按钮容器在移动端的布局 */
.buttons-container-responsive {
display: flex !important;
flex-direction: column !important;
gap: 0.5rem !important;
width: 100% !important;
align-items: stretch !important;
justify-content: stretch !important;
}
/* 移动端所有按钮样式 */
.buttons-container-responsive button {
width: 100% !important;
max-width: 100% !important;
justify-content: center !important;
text-align: center !important;
min-width: 0 !important;
flex-shrink: 0 !important;
box-sizing: border-box !important;
padding: 0.5rem 1rem !important;
font-size: 0.875rem !important;
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
}
}
/* 中等屏幕优化 */
@media (max-width: 1024px) and (min-width: 769px) {
.buttons-container-responsive {
flex-wrap: wrap !important;
justify-content: center !important;
}
.buttons-container-responsive button {
flex-shrink: 1 !important;
min-width: 0 !important;
padding-left: 0.75rem !important;
padding-right: 0.75rem !important;
}
}
/* 小屏幕(手机)特殊优化 - 确保按钮在边框内 */
@media (max-width: 640px) {
/* 强制重写主容器布局 */
.mobile-buttons-container {
display: flex !important;
flex-direction: column !important;
width: 100% !important;
padding: 0 !important;
margin: 0 !important;
gap: 1rem !important;
overflow: visible !important;
}
/* 搜索区域在移动端占满宽度 */
.mobile-search-controls {
width: 100% !important;
box-sizing: border-box !important;
}
/* 按钮区域完全重新布局 */
.buttons-container-responsive {
display: flex !important;
flex-direction: column !important;
width: 100% !important;
max-width: 100% !important;
gap: 0.5rem !important;
padding: 0 !important;
margin: 0 !important;
box-sizing: border-box !important;
overflow: visible !important;
}
/* 所有按钮统一样式 */
.buttons-container-responsive button {
display: flex !important;
align-items: center !important;
justify-content: center !important;
width: 100% !important;
max-width: 100% !important;
box-sizing: border-box !important;
padding: 0.5rem 1rem !important;
margin: 0 !important;
font-size: 0.875rem !important;
line-height: 1.25rem !important;
border-radius: 0.5rem !important;
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
flex-shrink: 0 !important;
}
/* 特别针对清空全部按钮 */
#deleteAllLogsBtn {
background-color: #f87171 !important;
border: 1px solid #f87171 !important;
}
#deleteAllLogsBtn:hover {
background-color: #ef4444 !important;
border: 1px solid #ef4444 !important;
}
/* 确保容器不会溢出父级 */
.mobile-buttons-container,
.mobile-buttons-container > *,
.buttons-container-responsive,
.buttons-container-responsive > * {
max-width: 100% !important;
box-sizing: border-box !important;
}
/* 额外的安全边距控制 */
.mobile-buttons-container .grid {
padding-left: 0 !important;
padding-right: 0 !important;
margin-left: 0 !important;
margin-right: 0 !important;
}
/* 确保主内容区域有适当的内边距 */
.rounded-xl.p-6 {
padding-left: 1rem !important;
padding-right: 1rem !important;
}
}
/* 超小屏幕额外优化 */
@media (max-width: 480px) {
.mobile-buttons-container {
gap: 0.75rem !important;
}
.buttons-container-responsive {
gap: 0.4rem !important;
}
.buttons-container-responsive button {
padding: 0.4rem 0.8rem !important;
font-size: 0.8rem !important;
}
/* 主容器内边距进一步缩小 */
.rounded-xl.p-6 {
padding-left: 0.75rem !important;
padding-right: 0.75rem !important;
}
/* 确保清空全部按钮文字不会太挤 */
#deleteAllLogsBtn i {
margin-right: 0.25rem !important;
}
}
input[type="text"],
@@ -586,7 +759,7 @@ endblock %} {% block head_extra_styles %}
class="text-3xl font-extrabold text-center text-gray-800 mb-4"
>
<img
src="/static/icons/logo.png"
src="{{ static_url('icons/logo.png') }}"
alt="Gemini Balance Logo"
class="h-9 inline-block align-middle mr-2"
/>
@@ -636,10 +809,10 @@ endblock %} {% block head_extra_styles %}
<!-- 搜索与操作控件 -->
<div
class="grid grid-cols-1 lg:grid-cols-[1fr_auto] items-center gap-4 mb-6"
class="grid grid-cols-1 lg:grid-cols-[1fr_auto] items-center gap-4 mb-6 mobile-buttons-container"
>
<div
class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 w-full"
class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 w-full mobile-search-controls"
>
<input
type="text"
@@ -684,7 +857,7 @@ endblock %} {% block head_extra_styles %}
</div>
</div>
</div>
<div class="flex items-center gap-3 flex-shrink-0">
<div class="flex items-center gap-3 flex-shrink-0 buttons-container-responsive">
<button
id="searchBtn"
class="flex items-center justify-center px-4 py-1.5 bg-blue-600 hover:bg-blue-700 text-white rounded-lg font-medium transition-all duration-200 shadow-sm hover:shadow-md whitespace-nowrap"
@@ -1041,7 +1214,7 @@ endblock %} {% block head_extra_styles %}
</div>
</div>
{% endblock %} {% block body_scripts %}
<script src="/static/js/error_logs.js"></script>
<script src="{{ static_url('js/error_logs.js') }}"></script>
<script>
// error_logs.html specific JS initialization (if any)
// e.g., initialize date pickers or other elements if needed

View File

@@ -1,14 +1,17 @@
"""
通用工具函数模块
"""
import json
import re
import base64
import requests
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
import logging
import base64
import json
import logging
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import requests
from app.config.config import Settings
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
helper_logger = logging.getLogger("app.utils")
@@ -20,23 +23,25 @@ VERSION_FILE_PATH = PROJECT_ROOT / "VERSION"
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
"""
从 base64 字符串中提取 MIME 类型和数据
Args:
base64_string: 可能包含 MIME 类型信息的 base64 字符串
Returns:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
if base64_string.startswith("data:"):
# 提取 MIME 类型和数据
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
mime_type = (
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
@@ -44,20 +49,20 @@ def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
def convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
Raises:
Exception: 如果获取图片失败
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
img_data = base64.b64encode(response.content).decode("utf-8")
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
@@ -66,64 +71,66 @@ def convert_image_to_base64(url: str) -> str:
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
"""
格式化JSON响应
Args:
data: 要格式化的数据
indent: 缩进空格数
Returns:
str: 格式化后的JSON字符串
"""
return json.dumps(data, indent=indent, ensure_ascii=False)
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
def parse_prompt_parameters(
prompt: str, default_ratio: str = "1:1"
) -> Tuple[str, int, str]:
"""
从prompt中解析参数
支持的格式:
- {n:数量} 例如: {n:2} 生成2张图片
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
Args:
prompt: 提示文本
default_ratio: 默认比例
Returns:
tuple: (清理后的提示文本, 图片数量, 比例)
"""
# 默认值
n = 1
aspect_ratio = default_ratio
# 解析n参数
n_match = re.search(r'{n:(\d+)}', prompt)
n_match = re.search(r"{n:(\d+)}", prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), '').strip()
# 解析ratio参数
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
prompt = prompt.replace(n_match.group(0), "").strip()
# 解析ratio参数
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()
prompt = prompt.replace(ratio_match.group(0), "").strip()
return prompt, n, aspect_ratio
def extract_image_urls_from_markdown(text: str) -> List[str]:
"""
从Markdown文本中提取图片URL
Args:
text: Markdown文本
Returns:
List[str]: 图片URL列表
"""
@@ -135,23 +142,22 @@ def extract_image_urls_from_markdown(text: str) -> List[str]:
def is_valid_api_key(key: str) -> bool:
"""
检查API密钥格式是否有效
Args:
key: API密钥
Returns:
bool: 如果密钥格式有效则返回True
"""
# 检查Gemini API密钥格式
if key.startswith('AIza'):
if key.startswith("AIza"):
return len(key) >= 30
# 检查OpenAI API密钥格式
if key.startswith('sk-'):
return len(key) >= 30
return False
# 检查OpenAI API密钥格式
if key.startswith("sk-"):
return len(key) >= 30
return False
def redact_key_for_logging(key: str) -> str:
@@ -177,15 +183,49 @@ def get_current_version(default_version: str = "0.0.0") -> str:
"""Reads the current version from the VERSION file."""
version_file = VERSION_FILE_PATH
try:
with version_file.open('r', encoding='utf-8') as f:
with version_file.open("r", encoding="utf-8") as f:
version = f.read().strip()
if not version:
helper_logger.warning(f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'.")
helper_logger.warning(
f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'."
)
return default_version
return version
except FileNotFoundError:
helper_logger.warning(f"VERSION file not found at '{version_file}'. Using default version '{default_version}'.")
helper_logger.warning(
f"VERSION file not found at '{version_file}'. Using default version '{default_version}'."
)
return default_version
except IOError as e:
helper_logger.error(f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'.")
helper_logger.error(
f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'."
)
return default_version
def is_image_upload_configured(settings: Settings) -> bool:
"""Return True only if a valid upload provider is selected and all required settings for that provider are present."""
provider = (getattr(settings, "UPLOAD_PROVIDER", "") or "").strip().lower()
if provider == "smms":
return bool(getattr(settings, "SMMS_SECRET_TOKEN", ""))
if provider == "picgo":
return bool(getattr(settings, "PICGO_API_KEY", ""))
if provider == "aliyun_oss":
return all(
[
getattr(settings, "OSS_ACCESS_KEY", ""),
getattr(settings, "OSS_ACCESS_KEY_SECRET", ""),
getattr(settings, "OSS_BUCKET_NAME", ""),
getattr(settings, "OSS_ENDPOINT", ""),
getattr(settings, "OSS_REGION", "")
]
)
if provider == "cloudflare_imgbed":
return all(
[
getattr(settings, "CLOUDFLARE_IMGBED_URL", ""),
getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""),
]
)
return False

127
app/utils/static_version.py Normal file
View File

@@ -0,0 +1,127 @@
"""
静态资源版本控制工具
用于给CSS和JS文件添加版本参数避免浏览器缓存问题
"""
import hashlib
import time
from functools import lru_cache
from pathlib import Path
from typing import Dict
from app.utils.helpers import get_current_version
class StaticVersionManager:
"""静态资源版本管理器"""
def __init__(self, static_dir: str = "app/static"):
self.static_dir = Path(static_dir)
self._version_cache: Dict[str, str] = {}
self._use_file_hash = True # 是否使用文件哈希作为版本号
def get_version_for_file(self, file_path: str) -> str:
"""
获取文件的版本号
Args:
file_path: 相对于static目录的文件路径'css/fonts.css'
Returns:
版本号字符串
"""
if self._use_file_hash:
return self._get_file_hash_version(file_path)
else:
return self._get_app_version()
def _get_file_hash_version(self, file_path: str) -> str:
"""基于文件内容生成哈希版本号"""
# 如果已经缓存过,直接返回
if file_path in self._version_cache:
return self._version_cache[file_path]
full_path = self.static_dir / file_path
if not full_path.exists():
# 文件不存在使用应用版本号作为fallback
version = self._get_app_version()
else:
try:
# 读取文件内容并计算MD5哈希
with open(full_path, "rb") as f:
content = f.read()
hash_object = hashlib.md5(content)
version = hash_object.hexdigest()[:8] # 取前8位
except Exception:
# 读取失败使用应用版本号作为fallback
version = self._get_app_version()
# 缓存结果
self._version_cache[file_path] = version
return version
def _get_app_version(self) -> str:
"""获取应用程序版本号"""
try:
return get_current_version().replace(".", "")
except Exception:
# 如果获取版本失败,使用时间戳
return str(int(time.time()))
def get_versioned_url(self, file_path: str) -> str:
"""
获取带版本参数的URL
Args:
file_path: 相对于static目录的文件路径
Returns:
带版本参数的URL
"""
version = self.get_version_for_file(file_path)
return f"/static/{file_path}?v={version}"
def clear_cache(self):
"""清空版本缓存"""
self._version_cache.clear()
# 全局实例
_static_version_manager = StaticVersionManager()
def get_static_url(file_path: str) -> str:
"""
获取静态资源的版本化URL
Args:
file_path: 相对于static目录的文件路径
Returns:
带版本参数的完整URL
Example:
get_static_url('css/fonts.css') -> '/static/css/fonts.css?v=a1b2c3d4'
get_static_url('js/config_editor.js') -> '/static/js/config_editor.js?v=e5f6g7h8'
"""
return _static_version_manager.get_versioned_url(file_path)
def clear_static_cache():
"""清空静态资源版本缓存"""
_static_version_manager.clear_cache()
@lru_cache(maxsize=128)
def get_cached_static_url(file_path: str) -> str:
"""
获取缓存的静态资源URL用于开发环境
Args:
file_path: 相对于static目录的文件路径
Returns:
带版本参数的完整URL
"""
return get_static_url(file_path)

View File

@@ -2,6 +2,12 @@ import requests
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
from enum import Enum
from typing import Optional, Any
import hashlib
import base64
import hmac
from datetime import datetime
from urllib.parse import quote
from app.log.logger import get_image_create_logger
class UploadErrorType(Enum):
"""上传错误类型枚举"""
@@ -179,9 +185,22 @@ class PicGoUploader(ImageUploader):
"""
try:
# 准备请求头
headers = {
"X-API-Key": self.api_key
}
headers = {}
# 构建请求URL
request_url = self.api_url
# 判断是否为默认PicGo URL如果是则使用header认证否则使用URL参数认证
if self.api_url == "https://www.picgo.net/api/1/upload":
headers["X-API-Key"] = self.api_key
else:
# 对于自定义URL将API key作为查询参数添加到URL中
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
parsed_url = urlparse(request_url)
query_params = parse_qs(parsed_url.query)
query_params["key"] = self.api_key
new_query = urlencode(query_params, doseq=True)
request_url = urlunparse(parsed_url._replace(query=new_query))
# 准备文件数据
files = {
@@ -190,7 +209,7 @@ class PicGoUploader(ImageUploader):
# 发送请求
response = requests.post(
self.api_url,
request_url,
headers=headers,
files=files
)
@@ -201,6 +220,34 @@ class PicGoUploader(ImageUploader):
# 解析响应
result = response.json()
# 处理自定义PicGo服务器的响应格式
if "success" in result and "result" in result:
# 自定义PicGo服务器格式: {"success": true, "result": ["url"]}
if result["success"]:
image_url = result["result"][0] if result["result"] and len(result["result"]) > 0 else ""
image_metadata = ImageMetadata(
width=0,
height=0,
filename=filename,
size=0,
url=image_url,
delete_url=None
)
return UploadResponse(
success=True,
code="success",
message="Upload success",
data=image_metadata
)
else:
raise UploadError(
message="Upload failed",
error_type=UploadErrorType.SERVER_ERROR,
status_code=400,
details=result
)
# 处理官方PicGo服务器的响应格式
# 验证上传是否成功
if result.get("status_code") != 200:
error_message = "Upload failed"
@@ -259,6 +306,191 @@ class PicGoUploader(ImageUploader):
)
class AliyunOSSUploader(ImageUploader):
"""阿里云OSS图片上传器"""
def __init__(self, access_key: str, access_key_secret: str, bucket_name: str,
endpoint: str, region: str, use_internal: bool = False):
"""
初始化阿里云OSS上传器
Args:
access_key: OSS访问密钥ID
access_key_secret: OSS访问密钥
bucket_name: OSS存储桶名称
endpoint: OSS端点地址
region: OSS区域
use_internal: 是否使用内网端点
"""
self.access_key = access_key
self.access_key_secret = access_key_secret
self.bucket_name = bucket_name
self.endpoint = endpoint
self.region = region
self.use_internal = use_internal
self.logger = get_image_create_logger()
# 构建请求URL
if not endpoint.startswith(('http://', 'https://')):
self.base_url = f"https://{bucket_name}.{endpoint}"
else:
self.base_url = f"{endpoint}/{bucket_name}"
self.logger.info(f"Initialized AliyunOSSUploader for bucket: {bucket_name}, region: {region}")
def _sign_request(self, method: str, path: str, headers: dict, content: bytes = b'') -> dict:
"""
为OSS请求生成签名
Args:
method: HTTP方法
path: 请求路径
headers: 请求头
content: 请求内容
Returns:
包含签名的请求头
"""
# 计算Content-MD5
content_md5 = base64.b64encode(hashlib.md5(content).digest()).decode('utf-8') if content else ''
# 设置日期
date = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
# 更新headers
headers['Date'] = date
if content_md5:
headers['Content-MD5'] = content_md5
headers['Content-Type'] = headers.get('Content-Type', 'image/png')
# 构建CanonicalizedOSSHeaders
oss_headers = []
for key, value in sorted(headers.items()):
if key.lower().startswith('x-oss-'):
oss_headers.append(f"{key.lower()}:{value}")
canonicalized_oss_headers = '\n'.join(oss_headers)
if canonicalized_oss_headers:
canonicalized_oss_headers += '\n'
# 构建CanonicalizedResource
canonicalized_resource = f"/{self.bucket_name}{path}"
# 构建StringToSign
string_to_sign = f"{method}\n{content_md5}\n{headers.get('Content-Type', '')}\n{date}\n{canonicalized_oss_headers}{canonicalized_resource}"
# 计算签名
signature = base64.b64encode(
hmac.new(
self.access_key_secret.encode('utf-8'),
string_to_sign.encode('utf-8'),
hashlib.sha1
).digest()
).decode('utf-8')
# 添加Authorization头
headers['Authorization'] = f"OSS {self.access_key}:{signature}"
return headers
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到阿里云OSS
Args:
file: 图片文件二进制数据
filename: 文件名将作为OSS对象的key
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
# 记录开始上传的日志
self.logger.info(f"Starting OSS upload for file: {filename}, size: {len(file)} bytes")
try:
# 构建对象路径
object_key = f"/{filename}"
# 准备请求头
headers = {
'Content-Type': 'image/png',
'x-oss-object-acl': 'public-read' # 设置为公共读
}
# 签名请求
signed_headers = self._sign_request('PUT', object_key, headers, file)
# 构建完整URL
upload_url = f"{self.base_url}{object_key}"
self.logger.debug(f"OSS upload URL: {upload_url}")
# 发送请求
response = requests.put(
upload_url,
data=file,
headers=signed_headers
)
# 检查响应状态
if response.status_code != 200:
error_msg = f"OSS upload failed with status {response.status_code}, response: {response.text}"
self.logger.error(f"OSS upload failed for {filename}: {error_msg}")
raise UploadError(
message=f"OSS upload failed with status {response.status_code}",
error_type=UploadErrorType.SERVER_ERROR,
status_code=response.status_code,
details={'response': response.text}
)
# 构建访问URL
if self.endpoint.startswith(('http://', 'https://')):
access_url = f"{self.endpoint}/{self.bucket_name}{object_key}"
else:
access_url = f"https://{self.bucket_name}.{self.endpoint}{object_key}"
# 构建图片元数据
image_metadata = ImageMetadata(
width=0, # OSS PUT不返回图片尺寸
height=0,
filename=filename,
size=len(file),
url=access_url,
delete_url=None # OSS需要单独的删除操作
)
# 记录上传成功的日志
self.logger.info(f"OSS upload successful for {filename}, URL: {access_url}")
return UploadResponse(
success=True,
code="success",
message="Upload to Aliyun OSS success",
data=image_metadata
)
except requests.RequestException as e:
error_msg = f"OSS upload request failed: {str(e)}"
self.logger.error(f"OSS upload request failed for {filename}: {error_msg}")
raise UploadError(
message=error_msg,
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except UploadError:
# UploadError 已经被记录了,直接重新抛出
raise
except Exception as e:
error_msg = f"OSS upload failed: {str(e)}"
self.logger.error(f"OSS upload unexpected error for {filename}: {error_msg}")
raise UploadError(
message=error_msg,
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class CloudFlareImgBedUploader(ImageUploader):
"""CloudFlare图床上传器"""
@@ -389,7 +621,7 @@ class ImageUploaderFactory:
credentials["secret_key"]
)
elif provider == "picgo":
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
api_url = credentials.get("api_url") or "https://www.picgo.net/api/1/upload"
return PicGoUploader(credentials["api_key"], api_url)
elif provider == "cloudflare_imgbed":
return CloudFlareImgBedUploader(
@@ -397,4 +629,13 @@ class ImageUploaderFactory:
credentials["base_url"],
credentials.get("upload_folder", ""),
)
elif provider == "aliyun_oss":
return AliyunOSSUploader(
credentials["access_key"],
credentials["access_key_secret"],
credentials["bucket_name"],
credentials["endpoint"],
credentials["region"],
credentials.get("use_internal", False)
)
raise ValueError(f"Unknown provider: {provider}")

View File

@@ -36,4 +36,13 @@ services:
interval: 10s # 每隔10秒检查一次
timeout: 5s # 每次检查的超时时间为5秒
retries: 3 # 重试3次失败后标记为 unhealthy
start_period: 30s # 容器启动后等待30秒再开始第一次健康检查
start_period: 30s # 容器启动后等待30秒再开始第一次健康检查
# adminer:
# image: adminer:latest
# container_name: gemini-balance-adminer
# restart: unless-stopped
# ports:
# - "8080:8080"
# depends_on:
# mysql:
# condition: service_healthy