Compare commits

...

83 Commits

Author SHA1 Message Date
snaily
cc36ba4c9e feat(config): 新增流式输出优化器开关配置
在环境变量示例文件(.env.example)和配置类(config.py)中新增 STREAM_OPTIMIZER_ENABLED 配置项,用于控制流式输出优化器的启用状态,默认设为 false

调整 Gemini 和 OpenAI 聊天服务的流式响应处理逻辑:
- 仅在流式优化器启用时(settings.STREAM_OPTIMIZER_ENABLED 为 true)
- 才会对文本内容执行流式输出优化处理
- 保持原有文本提取逻辑不变,仅增加配置条件判断

该变更使流式输出优化器变为可选功能,方便根据实际需求进行开关控制
2025-04-03 04:47:06 +08:00
snaily
baf643e884 feat: 新增请求超时配置及优化模型列表接口api_key获取方式
1. 新增功能:
   - 在`.env.example`中添加`TIME_OUT=300`配置项(包含中文注释)
   - 在`Settings`类中增加`TIME_OUT`字段(读取自`DEFAULT_TIMEOUT`)

2. 优化内容:
   - 生成配置:
     * 为`GenerationConfig`设置默认温度/TOP_P/TOP_K值
     * 移除`maxOutputTokens`默认值,改为可选传递
   - OpenAI请求:
     * 移除`max_tokens`默认值
     * 只有当`max_tokens`有值时才添加到请求payload
   - 日志优化:
     * 注释掉`stream_optimizer.py`中部分调试日志

3. 模型列表接口api_key获取方式
2025-04-03 03:12:59 +08:00
严浩
360bc9e48d feat(ci): 更新Docker发布工作流 2025-04-02 13:49:05 +08:00
snaily
c0a27d0542 Update README.md 2025-03-29 01:03:36 +08:00
snaily
84052a2179 feat(auth): 增强Gemini API的认证机制支持URL参数
- 将generate_content和stream_generate_content端点的认证依赖从verify_goog_api_key更改为verify_key_or_goog_api_key
- 使Gemini API同时支持URL参数中的key和请求头中的x-goog-api-key进行认证
- 提高API的灵活性,便于不同客户端集成
2025-03-28 23:44:40 +08:00
snaily
2e7ecd88b5 feat: 增强Gemini API tools参数处理
- 修改GeminiRequest模型,使tools字段支持单个工具对象或工具对象列表
- 在gemini_chat_service中添加类型转换逻辑,确保tools始终以列表形式处理
- 提高API的灵活性和兼容性
2025-03-28 20:50:01 +08:00
snaily
0b1f3dfc04 feat(auth): 支持x-goog-api-key请求头认证
- 添加verify_key_or_goog_api_key方法,支持同时验证URL参数中的key和请求头中的x-goog-api-key
- 更新models接口使用新的认证方法,提高与Google API客户端的兼容性
2025-03-28 19:27:42 +08:00
snaily
c691c7c1cf fix:当没有可用工具时返回空列表而非包含空字典的列表
在_build_tools函数中,当没有工具配置可用时(即tool为空字典),现在会返回空列表[]而不是[{}]。这个防御性编程修复可以避免向Gemini API发送无效的工具配置,防止可能的API调用错误。
2025-03-25 15:18:27 +08:00
snaily
97db7eebf1 chore:修改图片处理逻辑,统一使用base64编码
将_convert_image函数中对非data:image格式URL的处理方式从直接返回URL改为转换为base64编码的内联数据。这样无论图片是以data URI形式还是URL形式提供,都会统一转换为base64编码,确保与API交互时图片数据格式的一致性。
2025-03-25 13:23:17 +08:00
snaily
60dca70fcd fix: 改进图片显示和移除调试输出
优化图片链接格式,在图片前后添加空行以改善显示效果
注释掉OpenAI聊天服务中的调试打印语句
2025-03-22 03:38:45 +08:00
snaily
89b9f7919a feat: 添加对OpenAI工具调用功能的支持
改进消息转换器以处理OpenAI的tool_calls格式
添加JSON解析以正确转换函数调用参数
优化消息处理逻辑,增加更多空值检查
在流式响应中添加工具调用检测和处理
根据工具调用状态设置适当的finish_reason
2025-03-22 02:48:25 +08:00
Toddy
a8dc98ab6a fix tool use with function calling is unsupported error 2025-03-21 05:04:53 +00:00
snaily
b3a057b6ba refactor: 代码结构优化与常量化
将日志系统从 app/logger/ 移至 app/log/ 目录
将路由配置从 routers.py 重命名为 routes.py
将硬编码配置值移至 constants.py 中的默认常量
统一代码格式和导入排序
优化函数参数对齐方式
2025-03-20 21:59:18 +08:00
snaily
b14bb93d8f refactor: 项目结构优化与FastAPI生命周期更新
重构项目目录结构,提高代码组织性和可维护性

将schemas目录重命名为domain,更好地表达领域模型概念
将services目录细分为service/chat、service/image等子目录
将api目录重命名为router,更符合FastAPI惯例
创建utils目录存放通用工具函数
更新FastAPI应用程序生命周期管理

替换已弃用的on_event方法为推荐的lifespan事件处理器
添加应用程序关闭时的日志记录
代码质量改进

抽取常量到constants.py,减少硬编码值
添加helpers.py提供通用工具函数
优化配置管理,使用环境变量和默认值
完善文档字符串,提高代码可读性
2025-03-20 17:13:03 +08:00
snaily
8ca62707ea feat: 添加搜索模型配置并改进Markdown链接处理
在Dockerfile中添加SEARCH_MODELS环境变量,支持gemini-2.0-flash-exp和gemini-2.0-pro-exp模型
改进message_converter中的图片链接正则表达式
2025-03-19 19:56:50 +08:00
Toddy
21444ed6c7 chore: 统一从model_service读取模型列表 2025-03-18 18:05:00 +00:00
Toddy
ba292dbedd chore: 规范变量名 2025-03-18 17:54:18 +00:00
snaily
6ba58ce9d1 fix: 重构图片MIME类型转换逻辑
将"image/jpg"到"image/jpeg"的MIME类型转换逻辑从_convert_image函数移至_get_mime_type_and_data函数,避免代码重复并提高一致性。这确保了MIME类型的标准化处理发生在数据提取的同一位置。
2025-03-18 21:50:27 +08:00
snaily
16f16a3ae9 Merge branch 'pr/yangtb2024/13' 2025-03-18 21:46:34 +08:00
snaily
26dcb64687 fix: 将image/jpg MIME类型转换为标准的image/jpeg
修复了图像转换过程中的MIME类型处理,确保当遇到非标准的"image/jpg"类型时,将其转换为标准的"image/jpeg"类型。这样可以提高与接收图像数据的API和系统的兼容性
2025-03-18 21:35:19 +08:00
yangtb2024
df88492113 将chat-bison-001、text-bison-001和embedding-gecko-001添加到FILTERED_MODELS列表 2025-03-18 15:21:29 +08:00
yangtb2024
851bb9c09b 将 filtered_models 从硬编码改为可配置参数
1. 在 config.py 中添加 FILTERED_MODELS 配置项
2. 在 .env.example 中添加 FILTERED_MODELS 示例
3. 修改 model_service.py 以使用配置的过滤模型列表
4. 优化模型过滤逻辑
2025-03-18 14:47:58 +08:00
yangtb2024
0cac178572 Merge branch 'snailyp:main' into model 2025-03-18 12:44:09 +08:00
snaily
67c85c994a Merge pull request #14 from cr-zhichen/main
fix: 更新Cloudflare ImgBed上传请求URL,新增uploadNameType参数,以保持正确的目录结构命名。
2025-03-17 15:24:39 +08:00
cr-zhichen
ee979dd568 Merge branch 'main' of https://github.com/cr-zhichen/gemini-balance 2025-03-17 07:12:43 +00:00
cr-zhichen
e79a1ba56c feat: 更新CloudFlare ImgBed上传请求URL,新增uploadNameType参数,以保持正确的日期命名目录结构。 2025-03-17 07:10:21 +00:00
yangtb2024
016e6e06ee Filter out vision-based Gemini models from model list 2025-03-17 13:56:01 +08:00
snaily
8779a5f0b3 feat: 添加对 image-generation 模型的支持
在 gemini_chat_service 和 openai_chat_service 中添加对 "-image-generation" 后缀模型的支持
确保 image-generation 模型与 image 模型有相同的处理逻辑
2025-03-16 23:53:53 +08:00
cr-zhichen
89f2825ac7 feat: 新增对CloudFlare ImgBed的支持,更新环境变量和文档 2025-03-16 04:39:40 +00:00
snaily
985a12554d fix:修改OpenAI消息转换器中assistant消息处理逻辑,将特殊处理的目标从最后一条消息调整为倒数第二条消息。 2025-03-15 21:18:20 +08:00
snaily
37a7a140fc feat:改进消息转换器中的图像处理和消息分割逻辑
添加 _get_mime_type_and_data 函数从 base64 字符串中提取 MIME 类型和数据
修改 _convert_image 函数使用动态检测的 MIME 类型,而非硬编码
将 _process_text_with_image 中的 MIME 类型从 image/jpeg 改为 image/png
简化异常处理逻辑
优化 OpenAIMessageConverter 中的消息分割逻辑,仅对最后一个 assistant 消息进行分割处理
2025-03-15 21:11:10 +08:00
zhanghaoyu
28e67cc3fa 1. modify IMAGE_URL_PATTERN
2. modify import
2025-03-15 12:37:56 +08:00
zhanghaoyu7
d99a0bde93 feat: 新增图文上下文同步 2025-03-14 16:29:03 +08:00
snaily
cb5cd92041 fix: 修正Dockerfile中TOOLS_CODE_EXECUTION_ENABLED环境变量的拼写错误
将TOOLS_CODE_EXECUTION_ENABLED环境变量的值从"fasle"更正为"false",修复了拼写错误。
2025-03-14 13:46:31 +08:00
snaily
0be85e9536 feat(gemini_routes): 添加deepcopy导入
在gemini_routes.py中添加了Python标准库copy模块中的deepcopy函数导入,用于创建对象的深度副本,确保数据操作过程中不会意外修改原始对象。
2025-03-14 13:43:17 +08:00
Toddy
632dee38b3 check model before send request 2025-03-14 04:11:21 +00:00
Toddy
16c28bf1ba combine multiple system instructions into one 2025-03-14 02:55:29 +00:00
snaily
71af1db330 feat: 添加Gemini图像生成与处理功能
主要更新:

添加图像模型支持

新增MODEL_IMAGE配置项
在模型列表中添加gemini-2.0-flash-exp-image模型
修改ModelService以支持图像模型
增强图像处理能力

添加PicGoUploader类用于图像上传
实现图像响应处理逻辑(_extract_image_data)
支持base64图像数据的解码与上传
优化请求与响应处理

为图像模型添加特殊处理逻辑
修改API客户端以支持图像模型
更新GeminiRequest默认值
安全性调整

将TOOLS_CODE_EXECUTION_ENABLED默认设置为false
2025-03-14 00:27:23 +08:00
snaily
fb523f4a2e feat: 将 StreamOptimizer 参数改为可配置
将 StreamOptimizer 中的硬编码参数改为通过配置文件可配置的参数,提高了系统的灵活性。具体修改包括:

在 .env.example 中添加 stream_optimizer 相关配置参数
在 app/core/config.py 中添加对应的配置项
修改 app/services/chat/stream_optimizer.py 从配置中读取参数
在 README.md 中添加流式输出优化配置的详细说明
2025-03-06 16:56:01 +08:00
snaily
40e5ffa5f4 chore: Adjust StreamOptimizer parameters for improved performance
- Reduced long_text_threshold from 100 to 50 characters
- Decreased chunk_size from 10 to 5

These changes aim to optimize the streaming output for better user experience
and responsiveness, particularly for medium-length texts.
2025-03-06 16:45:35 +08:00
snaily
0871548b07 feat: 添加流式输出优化器以改善聊天体验
新增StreamOptimizer类用于优化API响应的流式输出
实现智能延迟调整算法,根据文本长度动态计算延迟时间
添加长文本分块输出功能,提高大段文本的显示效果
将优化器集成到Gemini和OpenAI聊天服务中
优化后的输出更接近自然打字效果,提升用户体验
2025-03-06 15:53:58 +08:00
snaily
5a44a76c48 Merge remote-tracking branch 'BetterAndBetterII/main' 2025-03-03 18:45:56 +08:00
Toddy
7b5b6c7d4c if role is tool then set to user 2025-03-03 08:23:04 +00:00
Yuzhong Zhang
68ed4da789 Update Dockerfile 2025-03-03 14:09:45 +08:00
Yuzhong Zhang
cdbca7ec62 优化dockerfile,增加docker-compose,async openai 2025-03-03 13:55:09 +08:00
Yuzhong Zhang
48d58ef2e8 异步生成完成 2025-03-03 13:41:06 +08:00
snaily
88d483c1ef Merge pull request #4 from toddyoe/main
chore: add system instruction to enhance compliance with function call
2025-02-27 19:17:39 +08:00
Toddy
8d48db026c chore: add system instruction to enhance compliance with function call 2025-02-27 10:35:25 +00:00
snaily
a592269198 Merge pull request #3 from toddyoe/main
feat: support function call
2025-02-27 16:14:50 +08:00
Toddy
18a5fe6109 fix: adapt gemini format 2025-02-27 07:35:12 +00:00
Toddy
348cbbdf2a feat: support function call 2025-02-27 05:36:39 +00:00
yinpeng
64235143dd ci: 精简 release workflow 文件
- 移除了 release-drafter 相关的步骤
- 保留了代码检出和创建 Release 的步骤
- 简化了工作流结构,提高了可读性
2025-02-15 01:06:32 +08:00
yinpeng
d566c28fa2 feat(gemini): 添加 API 密钥验证功能
- 在 gemini_routes.py 中添加 verify_key 路由,用于验证 API 密钥的有效性
- 在 keys_status 页面中添加验证按钮和相关逻辑
- 优化 keys_status 页面的样式,增加密钥验证相关 CSS 类
- 在 config.py 中添加 TEST_MODEL 设置,用于密钥验证测试
2025-02-15 01:00:47 +08:00
yinpeng
c1893d918e build: 更新发布流程并移除 release-drafter 配置
- 删除了 release-drafter.yml 文件,不再使用 release-drafter 自动生成发布说明
- 更新了 release.yml 工作流,移除了自动填充发布说明的步骤
- 保留了创建 ZIP 文件和上传构建文件的步骤,但标记为可选
2025-02-14 01:55:44 +08:00
yinpeng
4a02475cc1 ci: 优化发布流程,使用 release-drafter 自动生成发布说明 2025-02-14 01:46:24 +08:00
yinpeng
6e55a0985c ci: 优化发布流程,添加自动生成发布说明和资源打包功能 2025-02-14 01:40:20 +08:00
yinpeng
7b433aab91 refactor(static): 将 CSS 和 JS 代码分离到单独的文件中
- 将 auth.html 中的 CSS 代码提取到 auth.css 文件中
- 将 auth.html 中的 JS 代码提取到 auth.js 文件中
- 更新 auth.html,引入外部的 CSS 和 JS 文件
- 新增 keys_status.css 和 keys_status.js 文件,用于 keys_status 页面
2025-02-14 00:21:28 +08:00
yinpeng
fc7280bb18 feat: 优化滚动按钮显示逻辑,监听容器高度变化自动切换 2025-02-13 01:05:30 +08:00
yinpeng
8d9c99bda2 feat: 优化密钥状态页面滚动体验,添加容器滚动和渐变按钮样式 2025-02-13 00:49:44 +08:00
yinpeng
ab701f9415 docs: 完善 Web 界面功能文档,补充界面特性和交互细节 2025-02-12 23:40:05 +08:00
yinpeng
c3e0d4b64f feat: 添加页面底部版权信息和作者链接 2025-02-12 23:34:18 +08:00
yinpeng
5b7f4de63c feat: 优化密钥状态页面交互体验,添加分组折叠和刷新功能 2025-02-12 18:55:44 +08:00
yinpeng
ede27a5d70 refactor: 移除 retry_handler 中未使用的 KeyManager 导入 2025-02-12 17:48:09 +08:00
yinpeng
5a4619444b fix: 修复 Gemini 多段文本响应内容拼接问题 2025-02-12 17:47:03 +08:00
yinpeng
b3851441f1 refactor: 优化 RetryHandler 装饰器以支持动态 KeyManager 注入 2025-02-12 17:10:02 +08:00
yinpeng
44f956e4e4 feat: Add PWA support with manifest and ServiceWorker integration
- Mounted static files directory to serve PWA assets like manifest.json and ServiceWorker scripts.
- Updated `auth.html` and `keys_status.html` templates:
  - Added `<link>` for manifest and icons to support Progressive Web App (PWA) features.
  - Added meta tags for theme color and Apple web app capabilities.
  - Integrated ServiceWorker registration script for offline capabilities.
2025-02-12 16:20:34 +08:00
yinpeng
3aa4384b9d feat: Add responsive styles for auth and keys status pages
- Implement media queries to improve layout and UI for smaller screen sizes on `auth.html` and `keys_status.html`.
- Adjust container widths, font sizes, padding, and other styles for screen widths below 768px and 480px.
- Enhance mobile usability by making elements stack vertically, resizing fonts, and optimizing spacing for better readability and interaction.
2025-02-12 15:46:37 +08:00
yinpeng
6db4b56186 Refactor keys_status.html for improved layout and scrolling behavior
- Removed duplicated padding and simplified CSS for `body`, ensuring proper spacing with 20px padding.
- Adjusted `.container` styles:
  - Removed custom scrollbar styles and overflow-related attributes.
  - Centered the element with `margin: 20px auto`.
- Updated scroll behavior:
  - Changed scroll functions to operate on `window` instead of `.container`.
  - Modified event listeners to use `window` for detecting scroll events.
- Cleaned up redundant or unused styles and improved readability.
2025-02-12 15:30:44 +08:00
yinpeng
8e77773d5a Enhance UI/UX for keys_status.html
- Added smooth scroll functionality with "Scroll to Top" and "Scroll to Bottom" buttons.
- Introduced a `scroll-buttons` section with styled buttons for scrolling.
- Improved `#copyStatus` styling for better visibility and alignment.
- Adjusted `.container` to support scrollable content with hidden scrollbars and a max-height.
- Ensured proper z-index for new elements to prevent overlapping issues.
- Enhanced hover and active states for scroll buttons to improve user experience.
- Added event listeners to dynamically show/hide scroll buttons based on user scroll position.
2025-02-12 15:16:22 +08:00
yinpeng
343f40476a feat: Improve UI/UX for API Key Status page and add enhancements
- Updated the overall design aesthetics of the authentication page.
  - Added `fadeIn`, `slideDown`, `slideUp`, and `shake` animations for better user interaction.
  - Improved error message styling with a subtle background, padding, and animation.

- Enhanced "API Key Status" page:
  - Implemented new theme with gradient backgrounds and glassmorphism effect.
  - Redesigned headings with underlines and improved hierarchy.
  - Added FontAwesome icons to improve the visual appeal and clarity (e.g., checkmarks, warnings, keys).
  - Applied better spacing, padding, and hover effects to list items and buttons.
  - Introduced animations for key lists to create fluid transitions on page load.
  - Differentiated valid and invalid keys using badges with appropriate colors and icons.

- Copy Key Interaction:
  - Improved key copying functionality:
    - Added animations and hover effects to "Copy" buttons.
    - Updated the copied key selector logic to target `.key-text` for cleaner code.
    - Changed copy confirmation message for better clarity.
  - Styled the copy success message (`#copyStatus`) to appear fixed at the bottom with a blur effect.

- Key List Enhancements:
  - Added fail count badges for individual keys with red warning styles.
  - Styled buttons for batch copying to display icons alongside text, matching the overall design.

- Accessibility and Readability:
  - Refactored text sizes, weights, and alignments for smoother readability.
  - Enhanced color contrast and alignment for better accessibility.

Notes:
- New CSS animations have been smoothly integrated with no breaking changes.
- All changes prioritize maintaining current functionality while enhancing user experience.
2025-02-12 14:46:34 +08:00
yinpeng
e024d55006 feat: update workflows for docker and release processes
- Updated `.github/workflows/docker-publish.yml`:
  - Commented out the branch trigger for `main` in the `push` event to allow only tag-based Docker builds (tags like `v*.*.*`).

- Updated `.github/workflows/release.yml`:
  - Removed default release body template containing placeholder release notes. This simplifies the release creation process and avoids predefined content.
  - No functional changes to release asset upload configurations, minor format adjustment to ensure no missing newline at the file end.
2025-02-12 14:20:46 +08:00
yinpeng
17f1355099 feat: 增强应用日志记录并优化错误处理 2025-02-11 21:32:21 +08:00
yinpeng
f994c5d66d feat: 添加python-multipart依赖支持表单数据处理 2025-02-11 21:00:08 +08:00
yinpeng
e6bf45d778 refactor: 移除静态文件配置和相关依赖 2025-02-11 20:55:13 +08:00
yinpeng
8c9b802016 feat: 添加Web验证页面并优化密钥管理功能 2025-02-11 20:45:49 +08:00
yinpeng
d1f8a98ad0 feat: 支持在图片生成提示词中通过标记控制参数 2025-02-11 06:10:55 +08:00
yinpeng
30858937b5 feat: 支持图片生成响应格式切换并优化Markdown渲染 2025-02-11 05:13:36 +08:00
yinpeng
cb4d26778e docs: 完善环境变量配置文档并优化分类说明 2025-02-11 04:50:51 +08:00
yinpeng
0aefd4c03a feat: 添加OpenAI消息转换器组件 2025-02-11 04:27:17 +08:00
yinpeng
97b9b99235 feat: 根据模型类型选择不同的API密钥处理聊天请求 2025-02-11 04:20:28 +08:00
yinpeng
34a98389f5 fix: 修复图片生成模型重复添加的问题 2025-02-11 02:55:39 +08:00
yinpeng
4a73592f0e chore: 设置默认图片生成模型为imagen-3.0-generate-002 2025-02-11 02:46:23 +08:00
yinpeng
a354c9ebb1 chore: 升级Docker基础镜像至Python 3.10 2025-02-11 02:41:01 +08:00
53 changed files with 4066 additions and 999 deletions

View File

@@ -1,15 +1,31 @@
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
TOOLS_CODE_EXECUTION_ENABLED=true
TEST_MODEL=gemini-1.5-flash
IMAGE_MODELS=["gemini-2.0-flash-exp"]
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
TOOLS_CODE_EXECUTION_ENABLED=false
SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true
BASE_URL=https://generativelanguage.googleapis.com/v1beta
MAX_FAILURES=10
# 请求超时时间(秒)
TIME_OUT=300
#########################image_generate 相关配置###########################
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_OPTIMIZER_ENABLED=false
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
##########################################################################

View File

@@ -2,8 +2,6 @@ name: Docker Image CI
on:
push:
branches: [ "main" ]
tags: [ 'v*.*.*' ]
pull_request:
branches: [ "main" ]
@@ -43,20 +41,30 @@ jobs:
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=raw,value=latest,enable={{is_default_branch}}
# https://github.com/docker/metadata-action/tree/v5/?tab=readme-ov-file#semver
# Event: push, Ref: refs/head/main, Tags: main
# Event: push tag, Ref: refs/tags/v1.2.3, Tags: 1.2.3, 1.2, 1, latest
# Event: push tag, Ref: refs/tags/v2.0.8-rc1, Tags: 2.0.8-rc1
type=ref,event=branch
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha,format=long
type=semver,pattern={{major}}
labels: |
org.opencontainers.image.description=OpenAI API Compatible Server
org.opencontainers.image.source=${{ github.event.repository.html_url }}
- name: Build and push Docker image
uses: docker/build-push-action@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Build and push
uses: docker/build-push-action@v6
with:
file: Dockerfile
context: .
platforms: linux/amd64,linux/arm64
push: ${{ github.event_name != 'pull_request' }}
load: false
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=${{ github.workflow }}
cache-to: type=gha,scope=${{ github.workflow }}

View File

@@ -6,9 +6,10 @@ on:
- 'v*' # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0
jobs:
release:
update-release-draft:
permissions:
contents: write # 添加写入权限
contents: write
pull-requests: write
runs-on: ubuntu-latest
steps:
# Step 1: 检出代码库
@@ -24,10 +25,6 @@ jobs:
with:
tag_name: ${{ github.ref_name }}
release_name: ${{ github.ref_name }}
body: |
## Release Notes
- 自动发布版本。
- 请根据需求更新对应内容。
draft: false
prerelease: false
@@ -45,4 +42,4 @@ jobs:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./gemini-balance.zip # 替换为你的构建文件路径
asset_name: gemini-balance.zip # 替换为你的文件名
asset_content_type: application/zip
asset_content_type: application/zip

View File

@@ -1,17 +1,18 @@
FROM python:3.9-slim
FROM python:3.10-slim
WORKDIR /app
# 复制所需文件到容器中
COPY ./app /app/app
COPY ./requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=true
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
# Expose port
EXPOSE 8000

280
README.md
View File

@@ -1,14 +1,14 @@
# 🚀 FastAPI OpenAI (Gemini) 代理服务
# 🚀 Gemini 代理服务支持openai/gemini格式
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
## 📝 项目简介
本项目是一个基于 FastAPI 框架开发的高性能、易于部署的 OpenAI 和 Gemini API 代理服务。它不仅兼容 OpenAI 的 API 接口,还支持 Google 的 Gemini 模型,为用户提供灵活的模型选择。该代理服务内置了多 API Key 轮询、负载均衡、自动重试、访问控制Bearer Token 认证)、流式响应等功能,旨在简化 AI 应用的开发和部署流程。
本项目是一个基于 FastAPI 框架开发的高性能、易于部署的Gemini OpenAI兼容 和 Gemini API 代理服务。它不仅兼容 OpenAI 的 API 接口,还支持 Google 的 Gemini 原生接口。该代理服务内置了多 API Key 轮询、负载均衡、自动重试、访问控制Bearer Token 认证)、流式响应等功能,旨在简化 AI 应用的开发和部署流程。
**核心功能与优势:**
- **多模型支持**: 无缝切换 OpenAI 和 Gemini 模型
- **多协议支持**: 无缝切换 OpenAI兼容 和 Gemini 协议
- **智能 API Key 管理**: 自动轮询多个 API Key实现负载均衡和故障转移。
- **安全访问控制**: 使用 Bearer Token 进行身份验证,保护 API 访问。
- **流式响应支持**: 提供实时的流式数据传输,提升用户体验。
@@ -51,29 +51,145 @@
3. **配置**:
创建 `.env` 文件,并配置以下环境变量:
创建 `.env` 文件,并按以下分类配置环境变量:
```env
API_KEYS=["your-gemini-api-key-1", "your-gemini-api-key-2"] # 你的 Gemini API 密钥列表
ALLOWED_TOKENS=["your-access-token-1", "your-access-token-2"] # 允许访问的 Token 列表
BASE_URL="https://generativelanguage.googleapis.com/v1beta" # Gemini API 基础 URL, 保持默认即可
MODEL_SEARCH=["gemini-2.0-flash-exp"] # 启用搜索功能的模型列表
TOOLS_CODE_EXECUTION_ENABLED=false # 是否启用代码执行工具, 默认为 false
SHOW_SEARCH_LINK=true # 是否显示搜索链接
SHOW_THINKING_PROCESS=true # 是否显示思考过程
AUTH_TOKEN="" # 备用token, 如果不设置, 默认为 ALLOWED_TOKENS 的第一个
MAX_FAILURES=3 # 允许单个key失败的次数
# 基础配置
BASE_URL="https://generativelanguage.googleapis.com/v1beta" # Gemini API 基础 URL默认无需修改
MAX_FAILURES=3 # 允许单个key失败的次数默认3次
# 认证与安全配置
API_KEYS=["your-gemini-api-key-1", "your-gemini-api-key-2"] # Gemini API 密钥列表,用于负载均衡
ALLOWED_TOKENS=["your-access-token-1", "your-access-token-2"] # 允许访问的 Token 列表
AUTH_TOKEN="" # 超级管理员token具有所有权限默认使用 ALLOWED_TOKENS 的第一个
# 模型功能配置
TEST_MODEL="gemini-1.5-flash" # 用于测试密钥是否可用的模型名
SEARCH_MODELS=["gemini-2.0-flash-exp"] # 支持搜索功能的模型列表
IMAGE_MODELS=["gemini-2.0-flash-exp"] # 支持绘图功能的模型列表
TOOLS_CODE_EXECUTION_ENABLED=false # 是否启用代码执行工具默认false
SHOW_SEARCH_LINK=true # 是否在响应中显示搜索结果链接默认true
SHOW_THINKING_PROCESS=true # 是否显示模型思考过程默认true
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"] # 被禁用的模型列表
# 图片生成配置
PAID_KEY="your-paid-api-key" # 付费版API Key用于图片生成等高级功能
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型默认使用imagen-3.0
# 图片上传配置
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms、picgo、cloudflare_imgbed
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
PICGO_API_KEY="your-picogo-apikey" # PicoGo图床的API Key 可在 `https://www.picgo.net/settings/api` 获取
CLOUDFLARE_IMGBED_URL="https://xxxxxxx.pages.dev/upload" # CloudFlare 图床上传地址,可自行搭建:`https://github.com/MarSeventh/CloudFlare-ImgBed`
CLOUDFLARE_IMGBED_AUTH_CODE="your-cloudflare-imgber-auth-code" # CloudFlare图床的鉴权key可在项目后台设置若无鉴权则可直接置空。
# stream_optimizer 相关配置
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
```
- `API_KEYS`: 你的 Gemini API 密钥列表,支持多个 Key 轮询。
- `ALLOWED_TOKENS`: 允许访问的 Token 列表,用于 API 认证。
- `BASE_URL`: Gemini API 的基础 URL通常不需要修改。
- `MODEL_SEARCH`: 启用搜索功能的模型列表。
- `TOOLS_CODE_EXECUTION_ENABLED`: 是否启用代码执行工具, 默认为 `false`。
- `SHOW_SEARCH_LINK`: 是否显示搜索结果链接(当使用搜索模型时)。
- `SHOW_THINKING_PROCESS`: 是否显示模型的"思考"过程(对于某些模型)。
- `AUTH_TOKEN`: 主鉴权token(权限较大,注意保管), 如果不设置, 默认为 `ALLOWED_TOKENS` 的第一个。
- `MAX_FAILURES`: 允许单个 API Key 失败的次数,超过此次数后该 Key 将被标记为无效。
### 配置说明
#### 基础配置
- `BASE_URL`: Gemini API 的基础 URL
- 默认值: `https://generativelanguage.googleapis.com/v1beta`
- 说明: 通常无需修改,除非 API 地址发生变化
- `MAX_FAILURES`: API Key 允许的最大失败次数
- 默认值: `3`
- 说明: 超过此次数后Key 将被暂时标记为无效
#### 认证与安全配置
- `API_KEYS`: Gemini API 密钥列表
- 格式: JSON 数组字符串
- 用途: 支持多个 Key 轮询,实现负载均衡
- 建议: 至少配置 2 个 Key 以保证服务可用性
- `ALLOWED_TOKENS`: 访问令牌列表
- 格式: JSON 数组字符串
- 用途: 用于客户端认证
- 安全提示: 请使用足够复杂的令牌
- `AUTH_TOKEN`: 超级管理员令牌
- 可选配置,留空则使用 ALLOWED_TOKENS 的第一个
- 具有查看 API Key 状态等特权操作权限
#### 模型功能配置
- `TEST_MODEL`: 用于测试密钥可用性的模型
- 默认值: `"gemini-1.5-flash"`
- `SEARCH_MODELS`: 搜索功能支持的模型
- 默认值: `["gemini-2.0-flash-exp"]`
- 说明: 仅列表中的模型可使用搜索功能
- `IMAGE_MODELS`: 绘图功能支持的模型
- 默认值: `["gemini-2.0-flash-exp"]`
- 说明: 仅列表中的模型可使用绘图功能
- `FILTERED_MODELS`: 被禁用的模型列表
- 默认值: `["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]`
- 说明: 列表中的模型将被禁用
- `TOOLS_CODE_EXECUTION_ENABLED`: 代码执行功能
- 默认值: `false`
- 安全提示: 生产环境建议禁用
- `SHOW_SEARCH_LINK`: 搜索结果链接显示
- 默认值: `true`
- 用途: 控制搜索结果中是否包含原始链接
- `SHOW_THINKING_PROCESS`: 思考过程显示
- 默认值: `true`
- 用途: 显示模型的推理过程,便于调试
#### 图片生成配置
- `PAID_KEY`: 付费版 API Key
- 用途: 用于图片生成等高级功能
- 说明: 需要单独申请的付费版 Key
- `CREATE_IMAGE_MODEL`: 图片生成模型
- 默认值: `imagen-3.0-generate-002`
- 说明: 当前支持的最新图片生成模型
#### 图片上传配置
- `UPLOAD_PROVIDER`: 图片上传服务提供商
- 默认值: `smms`
- 可选值: `smms`, `picgo`, `cloudflare_imgbed`
- 说明: 用于选择图片上传的服务提供商。目前支持 SM.MS 图床, PicGo 图床, 以及 Cloudflare ImgBed。
- `SMMS_SECRET_TOKEN`: SM.MS API Token
- 用途: 用于图片上传到 SM.MS 图床的身份验证。
- 获取方式: 需要在 [SM.MS 官网](https://sm.ms/) 注册并获取。
- `PICGO_API_KEY`: PicGo API Key
- 用途: 用于图片上传到 PicGo 图床的身份验证。
- 获取方式: 可在 [PicGo 官网](https://www.picgo.net/settings/api) 的设置页面 API 选项中获取。
- `CLOUDFLARE_IMGBED_URL`: Cloudflare ImgBed 上传地址
- 用途: 指定 Cloudflare ImgBed 图床的上传 API 地址。
- 获取方式: 如果您自行搭建了 Cloudflare ImgBed 服务,请填写您的服务部署地址。参考 [Cloudflare-ImgBed 项目](https://github.com/MarSeventh/CloudFlare-ImgBed) 自行搭建。
- 注意: URL 必须以 `https://` 开头,并指向 `/upload` 路径 ,例如 `https://cloudflare-imgbed-7b0.pages.dev/upload`。
- `CLOUDFLARE_IMGBED_AUTH_CODE`: Cloudflare ImgBed 鉴权 Key
- 用途: 用于 Cloudflare ImgBed 图床的身份验证。
- 说明: 如果您的 Cloudflare ImgBed 服务启用了鉴权,请填写鉴权 Key。若未启用鉴权则留空即可。
- 获取方式: 在 Cloudflare ImgBed 项目的后台设置中获取,或在搭建时自行设置。
#### 流式输出优化配置
- `STREAM_MIN_DELAY`: 最小延迟时间
- 默认值: `0.016`(秒)
- 说明: 长文本输出时使用的最小延迟时间,值越小输出速度越快
- `STREAM_MAX_DELAY`: 最大延迟时间
- 默认值: `0.024`(秒)
- 说明: 短文本输出时使用的最大延迟时间,值越大输出速度越慢
- `STREAM_SHORT_TEXT_THRESHOLD`: 短文本阈值
- 默认值: `10`(字符)
- 说明: 小于此长度的文本被视为短文本,将使用最大延迟输出
- `STREAM_LONG_TEXT_THRESHOLD`: 长文本阈值
- 默认值: `50`(字符)
- 说明: 大于此长度的文本被视为长文本,将使用最小延迟并分块输出
- `STREAM_CHUNK_SIZE`: 长文本分块大小
- 默认值: `5`(字符)
- 说明: 长文本分块输出时,每个块的大小
### ▶️ 运行
@@ -109,13 +225,30 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer <your-token>`,其中 `<your-token>` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个或者 `AUTH_TOKEN`。
### 获取模型列表
### API 路由
本服务提供两种API路由
1. **OpenAI 兼容路由** (推荐)
- 基础路径: `/v1`
- 完全兼容OpenAI API格式
- 支持所有Gemini模型
2. **Gemini 原生路由**
- 基础路径: `/gemini/v1beta` 或 `/v1beta`
- 遵循Google原生API格式
- 适用于需要直接使用Gemini API的场景
### OpenAI兼容路由
#### 获取模型列表
- **URL**: `/v1/models`
- **Method**: `GET`
- **Header**: `Authorization: Bearer <your-token>`
- **Response**: 返回支持的所有模型列表,包括最新的`gemini-2.0-flash-exp-search`等模型
### 聊天补全 (Chat Completions)
#### 聊天补全 (Chat Completions)
- **URL**: `/v1/chat/completions`
- **Method**: `POST`
@@ -130,7 +263,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
"content": "你好"
}
],
"model": "gemini-1.5-flash-002",
"model": "gemini-1.5-flash",
"temperature": 0.7,
"stream": false,
"tools": [],
@@ -141,11 +274,34 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
}
```
- `messages`: 消息列表,格式与 OpenAI API 相同
- `model`: 模型名称,例如 `gemini-1.5-flash-002`。
- `stream`: 是否开启流式响应,`true` 或 `false`。
- `tools`: 使用的工具列表。
- 其他参数:与 OpenAI API 兼容的参数,如 `temperature`, `max_tokens` 等。
- `messages`: 消息列表,格式与 OpenAI API 相同
- `model`: 模型名称,支持所有Gemini模型包括:
- `gemini-1.5-flash`: 快速响应模型
- `gemini-2.0-flash-exp`: 实验性快速响应模型
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
- `stream`: 是否开启流式响应,`true` 或 `false`
- `tools`: 使用的工具列表
- 其他参数:与 OpenAI API 兼容的参数,如 `temperature`, `max_tokens` 等
### Gemini原生路由
#### 获取模型列表
- **URL**: `/gemini/v1beta/models` 或 `/v1beta/models`
- **Method**: `GET`
- **Header**: `Authorization: Bearer <your-token>`
#### 生成内容
- **URL**: `/gemini/v1beta/models/{model_name}:generateContent`
- **Method**: `POST`
- **Header**: `Authorization: Bearer <your-token>`
#### 流式生成内容
- **URL**: `/gemini/v1beta/models/{model_name}:streamGenerateContent`
- **Method**: `POST`
- **Header**: `Authorization: Bearer <your-token>`
### 获取词向量 (Embeddings)
@@ -169,12 +325,47 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
- **URL**: `/health`
- **Method**: `GET`
### 获取 API Key 列表
### Web界面功能
#### 验证页面 (auth.html)
- **URL**: `/auth`
- **说明**: 提供了一个简洁的Web界面用于验证访问令牌
- **功能特点**:
- 现代化的渐变背景设计
- 响应式布局,完美支持移动端
- 毛玻璃效果的卡片设计
- 优雅的动画效果(淡入、滑动、悬浮)
- 安全的令牌验证机制
- 清晰的错误提示功能
- PWA支持可安装为本地应用
- 底部版权信息和GitHub链接
- 支持暗色主题适配
#### API密钥状态管理 (keys_status.html)
- **URL**: `/v1/keys/list`
- **Method**: `GET`
- **Header**: `Authorization: Bearer <your-auth-token>`
- **说明**: 只有使用 `AUTH_TOKEN` 才能访问此接口, 用于获取有效和无效的 API Key 列表。
- **功能特点**:
- 只有使用 `AUTH_TOKEN` 才能访问此接口
- 分类展示API密钥状态有效/无效)
- 可折叠的密钥列表分组
- 每个密钥显示:
- 状态标识(有效/无效)
- 密钥内容
- 失败次数统计
- 高级功能:
- 一键复制单个密钥
- 批量复制分组密钥JSON格式
- 实时刷新功能
- 回到顶部/底部快捷按钮
- 界面特性:
- 响应式设计,适配各种屏幕
- 优雅的动画效果
- 操作反馈(复制成功提示)
- PWA支持
- 暗色主题适配
### 图片生成 (Image Generation)
@@ -186,12 +377,34 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```json
{
"model": "dall-e-3",
"prompt": "汉服美女",
"prompt": "{n:2} {ratio:16:9} 汉服美女",
"n": 1,
"size": "1024x1024"
}
```
**Prompt参数说明:**
prompt支持通过特殊标记来控制生成参数
1. 图片数量控制:
- 格式: `{n:数量}`
- 示例: `{n:2} 一只可爱的猫` - 生成2张图片
- 取值范围: 1-4
- 说明: 如果在prompt中指定了n将覆盖请求body中的n参数
2. 图片比例控制:
- 格式: `{ratio:宽:高}`
- 示例: `{ratio:16:9} 一片森林` - 生成16:9比例的图片
- 支持的比例: "1:1"、"3:4"、"4:3"、"9:16"、"16:9"
- 说明: 如果指定了size参数将优先使用size对应的比例
3. 参数组合:
- 示例: `{n:2} {ratio:16:9} 一片美丽的森林` - 生成2张16:9比例的图片
- 说明: 这些参数标记会自动从prompt中移除不会影响实际的图片生成提示词
> 注意n的取值范围[1,4], ratio取值范围"1:1"、"3:4"、"4:3"、"9:16" 和 "16:9"
## 📚 代码结构
```plaintext
@@ -267,6 +480,7 @@ A: 请检查以下几点:
A: 在请求的 Body 中,将 `stream` 参数设置为 `true` 即可。
**Q: 如何启用代码执行工具?**
A: 在 `.env` 文件的 `TOOLS_CODE_EXECUTION_ENABLED` 变量中, 设置为 `true` 即可。
## 📄 许可证

View File

@@ -1,95 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.core.logger import get_gemini_logger
from app.core.security import SecurityService
from app.schemas.gemini_models import GeminiRequest
from app.services.gemini_chat_service import GeminiChatService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat.retry_handler import RetryHandler
router = APIRouter(prefix="/gemini/v1beta")
router_v1beta = APIRouter(prefix="/v1beta")
logger = get_gemini_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
key_manager = KeyManager(settings.API_KEYS)
model_service = ModelService(settings.MODEL_SEARCH)
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(_=Depends(security_service.verify_key)):
"""获取可用的Gemini模型列表"""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental",
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
"topP": 0.95, "topK": 64, "maxTemperature": 2})
return models_json
@router.post("/models/{model_name}:generateContent")
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""非流式生成内容"""
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
try:
response = chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/models/{model_name}:streamGenerateContent")
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""流式生成内容"""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
try:
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")

57
app/config/config.py Normal file
View File

@@ -0,0 +1,57 @@
"""
应用程序配置模块
"""
from typing import List
from pydantic_settings import BaseSettings
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT
class Settings(BaseSettings):
"""应用程序配置"""
# API相关配置
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
TEST_MODEL: str = DEFAULT_MODEL
TIME_OUT: int = DEFAULT_TIMEOUT
# 模型相关配置
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
# 图像生成相关配置
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
# 流式输出优化器配置
STREAM_OPTIMIZER_ENABLED: bool = False
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 设置默认AUTH_TOKEN如果未提供
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
class Config:
env_file = ".env"
# 创建全局配置实例
settings = Settings()

71
app/core/application.py Normal file
View File

@@ -0,0 +1,71 @@
"""
应用程序工厂模块负责创建和配置FastAPI应用程序实例
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from app.config.config import settings
from app.log.logger import get_application_logger
from app.middleware.middleware import setup_middlewares
from app.exception.exceptions import setup_exception_handlers
from app.router.routes import setup_routers
from app.service.key.key_manager import get_key_manager_instance
from app.core.initialization import initialize_app
logger = get_application_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用程序生命周期管理器
Args:
app: FastAPI应用实例
"""
# 启动事件
logger.info("Application starting up...")
try:
# 初始化KeyManager
await get_key_manager_instance(settings.API_KEYS)
logger.info("KeyManager initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize KeyManager: {str(e)}")
raise
yield # 应用程序运行期间
# 关闭事件
logger.info("Application shutting down...")
def create_app() -> FastAPI:
"""
创建并配置FastAPI应用程序实例
Returns:
FastAPI: 配置好的FastAPI应用程序实例
"""
# 初始化应用程序
initialize_app()
# 创建FastAPI应用
app = FastAPI(
title="Gemini Balance API",
description="Gemini API代理服务支持负载均衡和密钥管理",
version="1.0.0",
lifespan=lifespan
)
# 配置静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# 配置中间件
setup_middlewares(app)
# 配置异常处理器
setup_exception_handlers(app)
# 配置路由
setup_routers(app)
return app

View File

@@ -1,29 +0,0 @@
from pydantic_settings import BaseSettings
from typing import List
class Settings(BaseSettings):
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = ""
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
def __init__(self):
super().__init__()
if not self.AUTH_TOKEN:
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else ""
class Config:
env_file = ".env"
settings = Settings()

41
app/core/constants.py Normal file
View File

@@ -0,0 +1,41 @@
"""
常量定义模块
"""
# API相关常量
API_VERSION = "v1beta"
DEFAULT_TIMEOUT = 300 # 秒
# 模型相关常量
SUPPORTED_ROLES = ["user", "model", "system"]
DEFAULT_MODEL = "gemini-1.5-flash"
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 8192
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 40
DEFAULT_FILTER_MODELS = [
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"chat-bison-001",
"text-bison-001",
"embedding-gecko-001"
]
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
# 图像生成相关常量
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
# 上传提供商
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
DEFAULT_UPLOAD_PROVIDER = "smms"
# 流式输出相关常量
DEFAULT_STREAM_MIN_DELAY = 0.016
DEFAULT_STREAM_MAX_DELAY = 0.024
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
DEFAULT_STREAM_CHUNK_SIZE = 5
# 正则表达式模式
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'

View File

@@ -0,0 +1,40 @@
"""
应用程序初始化模块
"""
from pathlib import Path
from typing import List
from app.log.logger import get_initialization_logger
logger = get_initialization_logger()
def ensure_directories_exist(directories: List[str]) -> None:
"""
确保指定的目录存在,如果不存在则创建
Args:
directories: 要确保存在的目录列表
"""
for directory in directories:
try:
Path(directory).mkdir(parents=True, exist_ok=True)
logger.info(f"Ensured directory exists: {directory}")
except Exception as e:
logger.error(f"Failed to create directory {directory}: {str(e)}")
def initialize_app() -> None:
"""
初始化应用程序,确保所需的目录和文件都存在
"""
# 确保必要的目录存在
required_directories = [
"app/static/css",
"app/static/js",
"app/static/icons",
"app/templates",
]
ensure_directories_exist(required_directories)
logger.info("Application initialization completed")

View File

@@ -1,10 +1,17 @@
from fastapi import HTTPException, Header
from typing import Optional
from app.core.logger import get_security_logger
from fastapi import Header, HTTPException
from app.config.config import settings
from app.log.logger import get_security_logger
logger = get_security_logger()
def verify_auth_token(token: str) -> bool:
return token == settings.AUTH_TOKEN
class SecurityService:
def __init__(self, allowed_tokens: list, auth_token: str):
self.allowed_tokens = allowed_tokens
@@ -17,7 +24,7 @@ class SecurityService:
return key
async def verify_authorization(
self, authorization: Optional[str] = Header(None)
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing Authorization header")
@@ -36,19 +43,26 @@ class SecurityService:
return token
async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str:
async def verify_goog_api_key(
self, x_goog_api_key: Optional[str] = Header(None)
) -> str:
"""验证Google API Key"""
if not x_goog_api_key:
logger.error("Missing x-goog-api-key header")
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
if (
x_goog_api_key not in self.allowed_tokens
and x_goog_api_key != self.auth_token
):
logger.error("Invalid x-goog-api-key")
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
return x_goog_api_key
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
async def verify_auth_token(
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing auth_token header")
raise HTTPException(status_code=401, detail="Missing auth_token header")
@@ -58,3 +72,22 @@ class SecurityService:
raise HTTPException(status_code=401, detail="Invalid auth_token")
return token
async def verify_key_or_goog_api_key(
self, key: Optional[str] = None , x_goog_api_key: Optional[str] = Header(None)
) -> str:
"""验证URL中的key或请求头中的x-goog-api-key"""
# 如果URL中的key有效直接返回
if key in self.allowed_tokens or key == self.auth_token:
return key
# 否则检查请求头中的x-goog-api-key
if not x_goog_api_key:
logger.error("Invalid key and missing x-goog-api-key header")
raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
logger.error("Invalid key and invalid x-goog-api-key")
raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")
return x_goog_api_key

View File

@@ -1,163 +0,0 @@
import requests
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
from enum import Enum
from typing import Optional, Any
class UploadErrorType(Enum):
"""上传错误类型枚举"""
NETWORK_ERROR = "network_error" # 网络请求错误
AUTH_ERROR = "auth_error" # 认证错误
INVALID_FILE = "invalid_file" # 无效文件
SERVER_ERROR = "server_error" # 服务器错误
PARSE_ERROR = "parse_error" # 响应解析错误
UNKNOWN = "unknown" # 未知错误
class UploadError(Exception):
"""图片上传错误异常类"""
def __init__(
self,
message: str,
error_type: UploadErrorType = UploadErrorType.UNKNOWN,
status_code: Optional[int] = None,
details: Optional[dict] = None,
original_error: Optional[Exception] = None
):
"""
初始化上传错误异常
Args:
message: 错误消息
error_type: 错误类型
status_code: HTTP状态码
details: 详细错误信息
original_error: 原始异常
"""
self.message = message
self.error_type = error_type
self.status_code = status_code
self.details = details or {}
self.original_error = original_error
# 构建完整错误信息
full_message = f"[{error_type.value}] {message}"
if status_code:
full_message = f"{full_message} (Status: {status_code})"
if details:
full_message = f"{full_message} - Details: {details}"
super().__init__(full_message)
@classmethod
def from_response(cls, response: Any, message: Optional[str] = None) -> "UploadError":
"""
从HTTP响应创建错误实例
Args:
response: HTTP响应对象
message: 自定义错误消息
"""
try:
error_data = response.json()
details = error_data.get("data", {})
return cls(
message=message or error_data.get("message", "Unknown error"),
error_type=UploadErrorType.SERVER_ERROR,
status_code=response.status_code,
details=details
)
except Exception:
return cls(
message=message or "Failed to parse error response",
error_type=UploadErrorType.PARSE_ERROR,
status_code=response.status_code
)
class SmMsUploader(ImageUploader):
API_URL = "https://sm.ms/api/v2/upload"
def __init__(self, api_key: str):
self.api_key = api_key
def upload(self, file: bytes, filename: str) -> UploadResponse:
try:
# 准备请求头
headers = {
"Authorization": f"Basic {self.api_key}"
}
# 准备文件数据
files = {
"smfile": (filename, file, "image/png")
}
# 发送请求
response = requests.post(
self.API_URL,
headers=headers,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证上传是否成功
if not result.get("success"):
raise UploadError(result.get("message", "Upload failed"))
# 转换为统一格式
data = result["data"]
image_metadata = ImageMetadata(
width=data["width"],
height=data["height"],
filename=data["filename"],
size=data["size"],
url=data["url"],
delete_url=data["delete"]
)
return UploadResponse(
success=True,
code="success",
message="Upload success",
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(f"Upload request failed: {str(e)}")
except (KeyError, ValueError) as e:
# 处理响应解析错误
raise UploadError(f"Invalid response format: {str(e)}")
except Exception as e:
# 处理其他未预期的错误
raise UploadError(f"Upload failed: {str(e)}")
class QiniuUploader(ImageUploader):
def __init__(self, access_key: str, secret_key: str):
self.access_key = access_key
self.secret_key = secret_key
def upload(self, file: bytes, filename: str) -> UploadResponse:
# 实现七牛云的具体上传逻辑
pass
class ImageUploaderFactory:
@staticmethod
def create(provider: str, **credentials) -> ImageUploader:
if provider == "smms":
return SmMsUploader(credentials["api_key"])
elif provider == "qiniu":
return QiniuUploader(
credentials["access_key"],
credentials["secret_key"]
)
raise ValueError(f"Unknown provider: {provider}")

View File

@@ -1,6 +1,8 @@
from typing import List, Optional, Dict, Any, Literal
from typing import List, Optional, Dict, Any, Literal, Union
from pydantic import BaseModel
from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
class SafetySetting(BaseModel):
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
@@ -13,9 +15,9 @@ class GenerationConfig(BaseModel):
responseSchema: Optional[Dict[str, Any]] = None
candidateCount: Optional[int] = 1
maxOutputTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
topK: Optional[int] = None
temperature: Optional[float] = DEFAULT_TEMPERATURE
topP: Optional[float] = DEFAULT_TOP_P
topK: Optional[int] = DEFAULT_TOP_K
presencePenalty: Optional[float] = None
frequencyPenalty: Optional[float] = None
responseLogprobs: Optional[bool] = None
@@ -33,8 +35,8 @@ class GeminiContent(BaseModel):
class GeminiRequest(BaseModel):
contents: List[GeminiContent]
tools: Optional[List[Dict[str, Any]]] = []
contents: List[GeminiContent] = []
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = None
systemInstruction: Optional[SystemInstruction] = None

View File

@@ -1,17 +1,19 @@
from pydantic import BaseModel
from typing import List, Optional, Union
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
class ChatRequest(BaseModel):
messages: List[dict]
model: str = "gemini-1.5-flash-002"
temperature: Optional[float] = 0.7
model: str = DEFAULT_MODEL
temperature: Optional[float] = DEFAULT_TEMPERATURE
stream: Optional[bool] = False
tools: Optional[List[dict]] = []
max_tokens: Optional[int] = 8192
max_tokens: Optional[int] = None
top_p: Optional[float] = DEFAULT_TOP_P
top_k: Optional[int] = DEFAULT_TOP_K
stop: Optional[List[str]] = []
top_p: Optional[float] = 0.9
top_k: Optional[int] = 40
class EmbeddingRequest(BaseModel):
@@ -27,4 +29,4 @@ class ImageGenerationRequest(BaseModel):
size: Optional[str] = "1024x1024"
quality: Optional[str] = ""
style: Optional[str] = ""
response_format: Optional[str] = "b64_json"
response_format: Optional[str] = "url"

140
app/exception/exceptions.py Normal file
View File

@@ -0,0 +1,140 @@
"""
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
"""
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.log.logger import get_exceptions_logger
logger = get_exceptions_logger()
class APIError(Exception):
"""API错误基类"""
def __init__(self, status_code: int, detail: str, error_code: str = None):
self.status_code = status_code
self.detail = detail
self.error_code = error_code or "api_error"
super().__init__(self.detail)
class AuthenticationError(APIError):
"""认证错误"""
def __init__(self, detail: str = "Authentication failed"):
super().__init__(
status_code=401, detail=detail, error_code="authentication_error"
)
class AuthorizationError(APIError):
"""授权错误"""
def __init__(self, detail: str = "Not authorized to access this resource"):
super().__init__(
status_code=403, detail=detail, error_code="authorization_error"
)
class ResourceNotFoundError(APIError):
"""资源未找到错误"""
def __init__(self, detail: str = "Resource not found"):
super().__init__(
status_code=404, detail=detail, error_code="resource_not_found"
)
class ModelNotSupportedError(APIError):
"""模型不支持错误"""
def __init__(self, model: str):
super().__init__(
status_code=400,
detail=f"Model {model} is not supported",
error_code="model_not_supported",
)
class APIKeyError(APIError):
"""API密钥错误"""
def __init__(self, detail: str = "Invalid or expired API key"):
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
class ServiceUnavailableError(APIError):
"""服务不可用错误"""
def __init__(self, detail: str = "Service temporarily unavailable"):
super().__init__(
status_code=503, detail=detail, error_code="service_unavailable"
)
def setup_exception_handlers(app: FastAPI) -> None:
"""
设置应用程序的异常处理器
Args:
app: FastAPI应用程序实例
"""
@app.exception_handler(APIError)
async def api_error_handler(request: Request, exc: APIError):
"""处理API错误"""
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
return JSONResponse(
status_code=exc.status_code,
content={"error": {"code": exc.error_code, "message": exc.detail}},
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""处理HTTP异常"""
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
return JSONResponse(
status_code=exc.status_code,
content={"error": {"code": "http_error", "message": exc.detail}},
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
"""处理请求验证错误"""
error_details = []
for error in exc.errors():
error_details.append(
{"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
)
logger.error(f"Validation Error: {error_details}")
return JSONResponse(
status_code=422,
content={
"error": {
"code": "validation_error",
"message": "Request validation failed",
"details": error_details,
}
},
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""处理通用异常"""
logger.exception(f"Unhandled Exception: {str(exc)}")
return JSONResponse(
status_code=500,
content={
"error": {
"code": "internal_server_error",
"message": "An unexpected error occurred",
}
},
)

View File

@@ -0,0 +1,174 @@
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
import json
import re
from typing import Any, Dict, List, Optional
import requests
import base64
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
class MessageConverter(ABC):
"""消息转换器基类"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
def _get_mime_type_and_data(base64_string):
"""
从 base64 字符串中提取 MIME 类型和数据。
参数:
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
返回:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
mime_type, encoded_data = _get_mime_type_and_data(image_url)
return {
"inline_data": {
"mime_type": mime_type,
"data": encoded_data
}
}
else:
encoded_data = _convert_image_to_base64(image_url)
return {
"inline_data": {
"mime_type": "image/png",
"data": encoded_data
}
}
def _convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
"""
处理可能包含图片URL的文本提取图片并转换为base64
Args:
text: 可能包含图片URL的文本
Returns:
List[Dict[str, Any]]: 包含文本和图片的部分列表
"""
parts = []
img_url_match = re.search(IMAGE_URL_PATTERN, text)
if img_url_match:
# 提取URL
img_url = img_url_match.group(2)
# 将URL对应的图片转换为base64
try:
base64_data = _convert_image_to_base64(img_url)
parts.append({
"inlineData": {
"mimeType": "image/png",
"data": base64_data
}
})
except Exception:
# 如果转换失败,回退到文本模式
parts.append({"text": text})
else:
# 没有图片URL作为纯文本处理
parts.append({"text": text})
return parts
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
system_instruction_parts = []
for idx, msg in enumerate(messages):
role = msg.get("role", "")
parts = []
# 特别处理最后一个assistant的消息按\n\n分割
if "content" in msg and isinstance(msg["content"], str) and msg["content"] and role == "assistant" and idx == len(messages) - 2:
# 按\n\n分割消息
content_parts = msg["content"].split("\n\n")
for part in content_parts:
if not part.strip(): # 跳过空内容
continue
# 处理可能包含图片的文本
parts.extend(_process_text_with_image(part))
elif "content" in msg and isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.extend(_process_text_with_image(msg["content"]))
elif "content" in msg and isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str) and content:
parts.append({"text": content})
elif isinstance(content, dict):
if content["type"] == "text" and content["text"]:
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
for tool_call in msg["tool_calls"]:
function_call = tool_call.get("function",{})
function_call["args"] = json.loads(function_call.get("arguments","{}"))
del function_call["arguments"]
parts.append({"functionCall": function_call})
if role not in SUPPORTED_ROLES:
if role == "tool":
role = "user"
else:
# 如果是最后一条消息,则认为是用户消息
if idx == len(messages) - 1:
role = "user"
else:
role = "model"
if parts:
if role == "system":
system_instruction_parts.extend(parts)
else:
converted_messages.append({"role": role, "parts": parts})
system_instruction = (
None
if not system_instruction_parts
else {
"role": "system",
"parts": system_instruction_parts,
}
)
return converted_messages, system_instruction

View File

@@ -1,10 +1,15 @@
# app/services/chat/response_handler.py
import base64
import json
import random
import string
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from typing import Dict, Any, List, Optional
import time
import uuid
from app.core.config import settings
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
class ResponseHandler(ABC):
@@ -29,40 +34,38 @@ class GeminiResponseHandler(ResponseHandler):
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = _extract_text(response, model, stream=True)
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
if not text and not tool_calls:
delta = {}
else:
delta = {"content": text, "role": "assistant"}
if tool_calls:
delta["tool_calls"] = tool_calls
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": text} if text else {},
"finish_reason": finish_reason
}]
"choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
}
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = _extract_text(response, model, stream=False)
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
"finish_reason": finish_reason,
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
@@ -127,74 +130,15 @@ def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason
}
def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str:
text = ""
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
text, tool_calls = "", []
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
# if "thinking" in model:
# if settings.SHOW_THINKING_PROCESS:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
if not parts:
return "", []
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
@@ -209,9 +153,12 @@ def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) ->
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
elif "inlineData" in parts[0]:
text = _extract_image_data(parts[0])
else:
text = ""
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(parts, gemini_format)
else:
if response.get("candidates"):
candidate = response["candidates"][0]
@@ -232,23 +179,93 @@ def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) ->
else:
text = candidate["content"]["parts"][0]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
text = ""
if "parts" in candidate["content"]:
for part in candidate["content"]["parts"]:
if "text" in part:
text += part["text"]
elif "inlineData" in part:
text += _extract_image_data(part)
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
else:
text = "暂无返回"
return text, tool_calls
def _extract_image_data(part: dict) -> str:
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
#将base64_data转成bytes数组
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename)
if upload_response.success:
text = f"\n\n![image]({upload_response.data.url})\n\n"
else:
text = ""
return text
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
"""提取工具调用信息"""
if not parts or not isinstance(parts, list):
return []
letters = string.ascii_lowercase + string.digits
tool_calls = list()
for i in range(len(parts)):
part = parts[i]
if not part or not isinstance(part, dict):
continue
item = part.get("functionCall", {})
if not item or not isinstance(item, dict):
continue
if gemini_format:
tool_calls.append(part)
else:
id = f"call_{''.join(random.sample(letters, 32))}"
name = item.get("name", "")
arguments = json.dumps(item.get("args", None) or {})
tool_calls.append(
{
"index": i,
"id": id,
"type": "function",
"function": {"name": name, "arguments": arguments},
}
)
return tool_calls
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = _extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}], "role": "model"}
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
content = {"parts": [{"text": text}], "role": "model"}
response["candidates"][0]["content"] = content
return response
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = _extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}], "role": "model"}
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
content = {"parts": [{"text": text}], "role": "model"}
response["candidates"][0]["content"] = content
return response

View File

@@ -1,20 +1,19 @@
# app/services/chat/retry_handler.py
from typing import TypeVar, Callable
from functools import wraps
from app.core.logger import get_retry_logger
from app.services.key_manager import KeyManager
from typing import Callable, TypeVar
T = TypeVar('T')
from app.log.logger import get_retry_logger
T = TypeVar("T")
logger = get_retry_logger()
class RetryHandler:
"""重试处理装饰器"""
def __init__(self, max_retries: int = 3, key_manager: KeyManager = None, key_arg: str = "api_key"):
def __init__(self, max_retries: int = 3, key_arg: str = "api_key"):
self.max_retries = max_retries
self.key_manager = key_manager
self.key_arg = key_arg
def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
@@ -27,15 +26,21 @@ class RetryHandler:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
logger.warning(
f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}"
)
if self.key_manager:
# 从函数参数中获取 key_manager
key_manager = kwargs.get("key_manager")
if key_manager:
old_key = kwargs.get(self.key_arg)
new_key = await self.key_manager.handle_api_failure(old_key)
new_key = await key_manager.handle_api_failure(old_key)
kwargs[self.key_arg] = new_key
logger.info(f"Switched to new API key: {new_key}")
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
logger.error(
f"All retry attempts failed, raising final exception: {str(last_exception)}"
)
raise last_exception
return wrapper

View File

@@ -0,0 +1,148 @@
# app/services/chat/stream_optimizer.py
import asyncio
import math
from typing import Any, AsyncGenerator, Callable, List
from app.config.config import settings
from app.core.constants import (
DEFAULT_STREAM_CHUNK_SIZE,
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
DEFAULT_STREAM_MAX_DELAY,
DEFAULT_STREAM_MIN_DELAY,
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
)
from app.log.logger import get_gemini_logger, get_openai_logger
logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger()
class StreamOptimizer:
"""流式输出优化器
提供流式输出优化功能,包括智能延迟调整和长文本分块输出。
"""
def __init__(
self,
logger=None,
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
):
"""初始化流式输出优化器
参数:
logger: 日志记录器
min_delay: 最小延迟时间(秒)
max_delay: 最大延迟时间(秒)
short_text_threshold: 短文本阈值(字符数)
long_text_threshold: 长文本阈值(字符数)
chunk_size: 长文本分块大小(字符数)
"""
self.logger = logger
self.min_delay = min_delay
self.max_delay = max_delay
self.short_text_threshold = short_text_threshold
self.long_text_threshold = long_text_threshold
self.chunk_size = chunk_size
def calculate_delay(self, text_length: int) -> float:
"""根据文本长度计算延迟时间
参数:
text_length: 文本长度
返回:
延迟时间(秒)
"""
if text_length <= self.short_text_threshold:
# 短文本使用较大延迟
return self.max_delay
elif text_length >= self.long_text_threshold:
# 长文本使用较小延迟
return self.min_delay
else:
# 中等长度文本使用线性插值计算延迟
# 使用对数函数使延迟变化更平滑
ratio = math.log(text_length / self.short_text_threshold) / math.log(
self.long_text_threshold / self.short_text_threshold
)
return self.max_delay - ratio * (self.max_delay - self.min_delay)
def split_text_into_chunks(self, text: str) -> List[str]:
"""将文本分割成小块
参数:
text: 要分割的文本
返回:
文本块列表
"""
return [
text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
]
async def optimize_stream_output(
self,
text: str,
create_response_chunk: Callable[[str], Any],
format_chunk: Callable[[Any], str],
) -> AsyncGenerator[str, None]:
"""优化流式输出
参数:
text: 要输出的文本
create_response_chunk: 创建响应块的函数,接收文本,返回响应块
format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串
返回:
异步生成器,生成格式化后的响应块
"""
if not text:
return
# 计算智能延迟时间
delay = self.calculate_delay(len(text))
# if self.logger:
# self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
# 根据文本长度决定输出方式
if len(text) >= self.long_text_threshold:
# 长文本:分块输出
chunks = self.split_text_into_chunks(text)
# if self.logger:
# self.logger.info(f"Long text: splitting into {len(chunks)} chunks")
for chunk_text in chunks:
chunk_response = create_response_chunk(chunk_text)
yield format_chunk(chunk_response)
await asyncio.sleep(delay)
else:
# 短文本:逐字符输出
for char in text:
char_chunk = create_response_chunk(char)
yield format_chunk(char_chunk)
await asyncio.sleep(delay)
# 创建默认的优化器实例,可以直接导入使用
openai_optimizer = StreamOptimizer(
logger=logger_openai,
min_delay=settings.STREAM_MIN_DELAY,
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE,
)
gemini_optimizer = StreamOptimizer(
logger=logger_gemini,
min_delay=settings.STREAM_MIN_DELAY,
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE,
)

View File

@@ -133,3 +133,23 @@ def get_retry_logger():
def get_image_create_logger():
return Logger.setup_logger("image_create")
def get_exceptions_logger():
return Logger.setup_logger("exceptions")
def get_application_logger():
return Logger.setup_logger("application")
def get_initialization_logger():
return Logger.setup_logger("initialization")
def get_middleware_logger():
return Logger.setup_logger("middleware")
def get_routes_logger():
return Logger.setup_logger("routes")

View File

@@ -1,44 +1,18 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.core.logger import get_main_logger
"""
应用程序入口模块
"""
from app.api import gemini_routes, openai_routes
import uvicorn
from app.middleware.request_logging_middleware import RequestLoggingMiddleware
from app.core.application import create_app
from app.log.logger import get_main_logger
# 创建应用程序实例
app = create_app()
# 配置日志
logger = get_main_logger()
app = FastAPI()
# 添加请求日志中间件
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)
)
# 包含所有路由
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
@app.get("/health")
@app.get("/")
async def health_check():
logger.info("Health check endpoint called")
return {"status": "healthy"}
if __name__ == "__main__":
logger.info("Starting application server...")
uvicorn.run(app, host="0.0.0.0", port=8001)

View File

@@ -0,0 +1,73 @@
"""
中间件配置模块,负责设置和配置应用程序的中间件
"""
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
from app.core.constants import API_VERSION
from app.core.security import verify_auth_token
from app.log.logger import get_middleware_logger
logger = get_middleware_logger()
class AuthMiddleware(BaseHTTPMiddleware):
"""
认证中间件,处理未经身份验证的请求
"""
async def dispatch(self, request: Request, call_next):
# 允许特定路径绕过身份验证
if (
request.url.path not in ["/", "/auth"]
and not request.url.path.startswith("/static")
and not request.url.path.startswith("/gemini")
and not request.url.path.startswith("/v1")
and not request.url.path.startswith(f"/{API_VERSION}")
and not request.url.path.startswith("/health")
and not request.url.path.startswith("/hf")
):
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning(f"Unauthorized access attempt to {request.url.path}")
return RedirectResponse(url="/")
logger.debug("Request authenticated successfully")
response = await call_next(request)
return response
def setup_middlewares(app: FastAPI) -> None:
"""
设置应用程序的中间件
Args:
app: FastAPI应用程序实例
"""
# 添加认证中间件
app.add_middleware(AuthMiddleware)
# 添加请求日志中间件(可选,默认注释掉)
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=[
"GET",
"POST",
"PUT",
"DELETE",
"OPTIONS",
], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)
)

View File

@@ -1,7 +1,9 @@
import json
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import json
from app.core.logger import get_request_logger
from app.log.logger import get_request_logger
logger = get_request_logger()
@@ -20,7 +22,9 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
# 尝试格式化JSON
try:
formatted_body = json.loads(body_str)
logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
logger.info(
f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
)
except json.JSONDecodeError:
logger.info("Request body is not valid JSON.")
except Exception as e:

178
app/router/gemini_routes.py Normal file
View File

@@ -0,0 +1,178 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.config.config import settings
from app.log.logger import get_gemini_logger
from app.core.security import SecurityService
from app.domain.gemini_models import GeminiContent, GeminiRequest
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.core.constants import API_VERSION
# 路由设置
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
logger = get_gemini_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
async def get_key_manager():
"""获取密钥管理器实例"""
return await get_key_manager_instance()
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
"""获取下一个可用的API密钥"""
return await key_manager.get_next_working_key()
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
_=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager)
):
"""获取可用的Gemini模型列表"""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
# 添加搜索模型
if model_service.search_models:
for name in model_service.search_models:
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if model_service.image_models:
for name in model_service.image_models:
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json
@router.post("/models/{model_name}:generateContent")
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager)
):
"""非流式生成内容"""
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/models/{model_name}:streamGenerateContent")
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(max_retries=3, key_arg="api_key")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager)
):
"""流式生成内容"""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Streaming request failed") from e
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str):
"""验证Gemini API密钥的有效性"""
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
logger.info("Verifying API key validity")
try:
key_manager = await get_key_manager()
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
# 使用generate_content接口测试key的有效性
gemini_request = GeminiRequest(
contents=[
GeminiContent(
role="user",
parts=[{"text": "hi"}]
)
]
)
response = await chat_service.generate_content(
settings.TEST_MODEL,
gemini_request,
api_key
)
if response:
return JSONResponse({"status": "valid"})
return JSONResponse({"status": "invalid"})
except Exception as e:
logger.error(f"Key verification failed: {str(e)}")
return JSONResponse({"status": "invalid", "error": str(e)})

View File

@@ -1,49 +1,68 @@
from fastapi import HTTPException, APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.core.logger import get_openai_logger
from app.config.config import settings
from app.core.security import SecurityService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
from app.services.chat.retry_handler import RetryHandler
from app.services.embedding_service import EmbeddingService
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.openai_chat_service import OpenAIChatService
from app.domain.openai_models import (
ChatRequest,
EmbeddingRequest,
ImageGenerationRequest,
)
from app.handler.retry_handler import RetryHandler
from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
router = APIRouter()
logger = get_openai_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
key_manager = KeyManager(settings.API_KEYS)
model_service = ModelService(settings.MODEL_SEARCH)
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService()
async def get_key_manager():
return await get_key_manager_instance()
async def get_next_working_key_wrapper(
key_manager: KeyManager = Depends(get_key_manager),
):
return await key_manager.get_next_working_key()
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(_=Depends(security_service.verify_authorization)):
async def list_models(
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "list_models" + "-" * 50)
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
try:
return model_service.get_gemini_openai_models(api_key)
except Exception as e:
logger.error(f"Error getting models list: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
raise HTTPException(
status_code=500, detail="Internal server error while fetching models list"
) from e
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
@RetryHandler(max_retries=3, key_arg="api_key")
async def chat_completion(
request: ChatRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
request: ChatRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager),
):
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
@@ -53,14 +72,23 @@ async def chat_completion(
logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(
status_code=400, detail=f"Model {request.model} is not supported"
)
try:
response = await chat_service.create_image_chat_completion(request=request)
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
response = await chat_service.create_image_chat_completion(request=request)
else:
response = await chat_service.create_chat_completion(request, api_key)
# 处理流式响应
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
logger.info("Chat completion request successful")
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@@ -69,8 +97,8 @@ async def chat_completion(
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
request: ImageGenerationRequest,
_=Depends(security_service.verify_authorization),
request: ImageGenerationRequest,
_=Depends(security_service.verify_authorization),
):
logger.info("-" * 50 + "generate_image" + "-" * 50)
logger.info(f"Handling image generation request for prompt: {request.prompt}")
@@ -79,17 +107,19 @@ async def generate_image(
response = image_create_service.generate_images(request)
logger.info("Image generation request successful")
return response
except Exception as e:
logger.error(f"Image generation request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Image generation request failed") from e
raise HTTPException(
status_code=500, detail="Image generation request failed"
) from e
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "embedding" + "-" * 50)
logger.info(f"Handling embedding request for model: {request.model}")
@@ -109,7 +139,8 @@ async def embedding(
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
_=Depends(security_service.verify_auth_token),
_=Depends(security_service.verify_auth_token),
key_manager: KeyManager = Depends(get_key_manager),
):
"""获取有效和无效的API key列表"""
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
@@ -120,13 +151,12 @@ async def get_keys_list(
"status": "success",
"data": {
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"]
"invalid_keys": keys_status["invalid_keys"],
},
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
}
except Exception as e:
logger.error(f"Error getting keys list: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal server error while fetching keys list"
status_code=500, detail="Internal server error while fetching keys list"
) from e

114
app/router/routes.py Normal file
View File

@@ -0,0 +1,114 @@
"""
路由配置模块,负责设置和配置应用程序的路由
"""
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from app.core.security import verify_auth_token
from app.log.logger import get_routes_logger
from app.router import gemini_routes, openai_routes
from app.service.key.key_manager import get_key_manager_instance
logger = get_routes_logger()
# 配置Jinja2模板
templates = Jinja2Templates(directory="app/templates")
def setup_routers(app: FastAPI) -> None:
"""
设置应用程序的路由
Args:
app: FastAPI应用程序实例
"""
# 包含API路由
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
# 添加页面路由
setup_page_routes(app)
# 添加健康检查路由
setup_health_routes(app)
def setup_page_routes(app: FastAPI) -> None:
"""
设置页面相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/", response_class=HTMLResponse)
async def auth_page(request: Request):
"""认证页面"""
return templates.TemplateResponse("auth.html", {"request": request})
@app.post("/auth")
async def authenticate(request: Request):
"""处理认证请求"""
try:
form = await request.form()
auth_token = form.get("auth_token")
if not auth_token:
logger.warning("Authentication attempt with empty token")
return RedirectResponse(url="/", status_code=302)
if verify_auth_token(auth_token):
logger.info("Successful authentication")
response = RedirectResponse(url="/keys", status_code=302)
response.set_cookie(
key="auth_token", value=auth_token, httponly=True, max_age=3600
)
return response
logger.warning("Failed authentication attempt with invalid token")
return RedirectResponse(url="/", status_code=302)
except Exception as e:
logger.error(f"Authentication error: {str(e)}")
return RedirectResponse(url="/", status_code=302)
@app.get("/keys", response_class=HTMLResponse)
async def keys_page(request: Request):
"""密钥管理页面"""
try:
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning("Unauthorized access attempt to keys page")
return RedirectResponse(url="/", status_code=302)
key_manager = await get_key_manager_instance()
keys_status = await key_manager.get_keys_by_status()
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
return templates.TemplateResponse(
"keys_status.html",
{
"request": request,
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"],
"total": total,
},
)
except Exception as e:
logger.error(f"Error retrieving keys status: {str(e)}")
raise
def setup_health_routes(app: FastAPI) -> None:
"""
设置健康检查相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/health")
async def health_check(request: Request):
"""健康检查端点"""
logger.info("Health check endpoint called")
return {"status": "healthy"}

View File

@@ -0,0 +1,197 @@
# app/services/chat_service.py
import json
from typing import Any, AsyncGenerator, Dict, List
from app.config.config import settings
from app.domain.gemini_models import GeminiRequest
from app.handler.response_handler import GeminiResponseHandler
from app.handler.stream_optimizer import gemini_optimizer
from app.log.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import KeyManager
logger = get_gemini_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""构建工具"""
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
record = dict()
for item in tools:
if not item or not isinstance(item, dict):
continue
for k, v in item.items():
if k == "functionDeclarations" and v and isinstance(v, list):
functions = record.get("functionDeclarations", [])
functions.extend(v)
record["functionDeclarations"] = functions
else:
record[k] = v
return record
tool = dict()
if payload and isinstance(payload, dict) and "tools" in payload:
if payload.get("tools") and isinstance(payload.get("tools"), dict):
payload["tools"] = [payload.get("tools")]
items = payload.get("tools", [])
if items and isinstance(items, list):
tool.update(_merge_tools(items))
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not _has_image_parts(payload.get("contents", []))
):
tool["codeExecution"] = {}
if model.endswith("-search"):
tool["googleSearch"] = {}
# 解决 "Tool use with function calling is unsupported" 问题
if tool.get("functionDeclarations"):
tool.pop("googleSearch", None)
tool.pop("codeExecution", None)
return [tool] if tool else []
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload"""
request_dict = request.model_dump()
if request.generationConfig:
if request.generationConfig.maxOutputTokens is None:
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
request_dict["generationConfig"].pop("maxOutputTokens")
payload = {
"contents": request_dict.get("contents", []),
"tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model),
"generationConfig": request_dict.get("generationConfig", {}),
"systemInstruction": request_dict.get("systemInstruction", ""),
}
if model.endswith("-image") or model.endswith("-image-generation"):
payload.pop("systemInstruction")
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
return payload
class GeminiChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager):
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
"""从响应中提取文本内容"""
if not response.get("candidates"):
return ""
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if parts and "text" in parts[0]:
return parts[0].get("text", "")
return ""
def _create_char_response(
self, original_response: Dict[str, Any], text: str
) -> Dict[str, Any]:
"""创建包含指定文本的响应"""
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
if response_copy.get("candidates") and response_copy["candidates"][0].get(
"content", {}
).get("parts"):
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
return response_copy
async def generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
"""生成内容"""
payload = _build_payload(model, request)
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> AsyncGenerator[str, None]:
"""流式生成内容"""
retries = 0
max_retries = 3
payload = _build_payload(model, request)
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
line = line[6:]
response_data = self.response_handler.handle_response(
json.loads(line), model, stream=True
)
text = self._extract_text_from_response(response_data)
# 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理
if text and settings.STREAM_OPTIMIZER_ENABLED:
# 使用流式输出优化器处理文本输出
async for (
optimized_chunk
) in gemini_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_response(response_data, t),
lambda c: "data: " + json.dumps(c) + "\n\n",
):
yield optimized_chunk
else:
# 如果没有文本内容(如工具调用等),整块输出
yield "data: " + json.dumps(response_data) + "\n\n"
logger.info("Streaming completed successfully")
break
except Exception as e:
retries += 1
logger.warning(
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
)
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(
f"Max retries ({max_retries}) reached for streaming. Raising error"
)
break

View File

@@ -0,0 +1,306 @@
# app/services/chat_service.py
import json
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from app.config.config import settings
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
from app.handler.stream_optimizer import openai_optimizer
from app.log.logger import get_openai_logger
from app.service.client.api_client import GeminiApiClient
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager
logger = get_openai_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _build_tools(
request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建工具"""
tool = dict()
model = request.model
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (
model.endswith("-search")
or "-thinking" in model
or model.endswith("-image")
or model.endswith("-image-generation")
)
and not _has_image_parts(messages)
):
tool["codeExecution"] = {}
if model.endswith("-search"):
tool["googleSearch"] = {}
# 将 request 中的 tools 合并到 tools 中
if request.tools:
function_declarations = []
for item in request.tools:
if not item or not isinstance(item, dict):
continue
if item.get("type", "") == "function" and item.get("function"):
function = deepcopy(item.get("function"))
parameters = function.get("parameters", {})
if parameters.get("type") == "object" and not parameters.get("properties", {}):
function.pop("parameters", None)
function_declarations.append(function)
if function_declarations:
# 按照 function 的 name 去重
names, functions = set(), []
for fc in function_declarations:
if fc.get("name") not in names:
names.add(fc.get("name"))
functions.append(fc)
tool["functionDeclarations"] = functions
# 解决 "Tool use with function calling is unsupported" 问题
if tool.get("functionDeclarations"):
tool.pop("googleSearch", None)
tool.pop("codeExecution", None)
return [tool] if tool else []
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
# if (
# "2.0" in model
# and "gemini-2.0-flash-thinking-exp" not in model
# and "gemini-2.0-pro-exp" not in model
# ):
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
def _build_payload(
request: ChatRequest,
messages: List[Dict[str, Any]],
instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""构建请求payload"""
payload = {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
"stopSequences": request.stop,
"topP": request.top_p,
"topK": request.top_k,
},
"tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model),
}
if request.max_tokens is not None:
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
if (
instruction
and isinstance(instruction, dict)
and instruction.get("role") == "system"
and instruction.get("parts")
and not request.model.endswith("-image")
and not request.model.endswith("-image-generation")
):
payload["systemInstruction"] = instruction
return payload
class OpenAIChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager = None):
self.message_converter = OpenAIMessageConverter()
self.response_handler = OpenAIResponseHandler(config=None)
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.image_create_service = ImageCreateService()
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
"""从OpenAI响应块中提取文本内容"""
if not chunk.get("choices"):
return ""
choice = chunk["choices"][0]
if "delta" in choice and "content" in choice["delta"]:
return choice["delta"]["content"]
return ""
def _create_char_openai_chunk(
self, original_chunk: Dict[str, Any], text: str
) -> Dict[str, Any]:
"""创建包含指定文本的OpenAI响应块"""
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
chunk_copy["choices"][0]["delta"]["content"] = text
return chunk_copy
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
messages, instruction = self.message_converter.convert(request.messages)
# 构建请求payload
payload = _build_payload(request, messages, instruction)
if request.stream:
return self._handle_stream_completion(request.model, payload, api_key)
return await self._handle_normal_completion(request.model, payload, api_key)
async def _handle_normal_completion(
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(
response, model, stream=False, finish_reason="stop"
)
async def _handle_stream_completion(
self, model: str, payload: Dict[str, Any], api_key: str
) -> AsyncGenerator[str, None]:
"""处理流式聊天完成,添加重试逻辑"""
retries = 0
max_retries = 3
while retries < max_retries:
try:
tool_call_flag = False
async for line in self.api_client.stream_generate_content(
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
chunk = json.loads(line[6:])
openai_chunk = self.response_handler.handle_response(
chunk, model, stream=True, finish_reason=None
)
if openai_chunk:
# 提取文本内容
text = self._extract_text_from_openai_chunk(openai_chunk)
if text and settings.STREAM_OPTIMIZER_ENABLED:
# 使用流式输出优化器处理文本输出
async for (
optimized_chunk
) in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(
openai_chunk, t
),
lambda c: f"data: {json.dumps(c)}\n\n",
):
yield optimized_chunk
else:
# 如果没有文本内容(如工具调用等),整块输出
if "tool_calls" in json.dumps(openai_chunk):
tool_call_flag = True
yield f"data: {json.dumps(openai_chunk)}\n\n"
if tool_call_flag:
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls'))}\n\n"
else:
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
break # 成功后退出循环
except Exception as e:
retries += 1
logger.warning(
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
)
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(
f"Max retries ({max_retries}) reached for streaming. Raising error"
)
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"
break
async def create_image_chat_completion(
self,
request: ChatRequest,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
image_generate_request = ImageGenerationRequest()
image_generate_request.prompt = request.messages[-1]["content"]
image_res = self.image_create_service.generate_images_chat(
image_generate_request
)
if request.stream:
return self._handle_stream_image_completion(request.model, image_res)
else:
return self._handle_normal_image_completion(request.model, image_res)
async def _handle_stream_image_completion(
self, model: str, image_data: str
) -> AsyncGenerator[str, None]:
if image_data:
openai_chunk = self.response_handler.handle_image_chat_response(
image_data, model, stream=True, finish_reason=None
)
if openai_chunk:
# 提取文本内容
text = self._extract_text_from_openai_chunk(openai_chunk)
if text:
# 使用流式输出优化器处理文本输出
async for (
optimized_chunk
) in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(openai_chunk, t),
lambda c: f"data: {json.dumps(c)}\n\n",
):
yield optimized_chunk
else:
# 如果没有文本内容如图片URL等整块输出
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Image chat streaming completed successfully")
def _handle_normal_image_completion(
self, model: str, image_data: str
) -> Dict[str, Any]:
return self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)

View File

@@ -4,6 +4,8 @@ from typing import Dict, Any, AsyncGenerator
import httpx
from abc import ABC, abstractmethod
from app.core.constants import DEFAULT_TIMEOUT
class ApiClient(ABC):
"""API客户端基类"""
@@ -20,17 +22,25 @@ class ApiClient(ABC):
class GeminiApiClient(ApiClient):
"""Gemini API客户端"""
def __init__(self, base_url: str, timeout: int = 300):
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
self.base_url = base_url
self.timeout = timeout
def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
def _get_real_model(self, model: str) -> str:
if model.endswith("-search"):
model = model[:-7]
with httpx.Client(timeout=timeout) as client:
if model.endswith("-image"):
model = model[:-6]
return model
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = client.post(url, json=payload)
response = await client.post(url, json=payload)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
@@ -38,8 +48,8 @@ class GeminiApiClient(ApiClient):
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response:

View File

@@ -1,9 +1,9 @@
from typing import Union, List
from typing import List, Union
import openai
from openai.types import CreateEmbeddingResponse
from app.core.logger import get_embeddings_logger
from app.log.logger import get_embeddings_logger
logger = get_embeddings_logger()

View File

@@ -0,0 +1,164 @@
import base64
import time
import uuid
from google import genai
from google.genai import types
from app.config.config import settings
from app.core.constants import VALID_IMAGE_RATIOS
from app.domain.openai_models import ImageGenerationRequest
from app.log.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory
logger = get_image_create_logger()
class ImageCreateService:
def __init__(self, aspect_ratio="1:1"):
self.image_model = settings.CREATE_IMAGE_MODEL
self.paid_key = settings.PAID_KEY
self.aspect_ratio = aspect_ratio
def parse_prompt_parameters(self, prompt: str) -> tuple:
"""从prompt中解析参数
支持的格式:
- {n:数量} 例如: {n:2} 生成2张图片
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
"""
import re
# 默认值
n = 1
aspect_ratio = self.aspect_ratio
# 解析n参数
n_match = re.search(r"{n:(\d+)}", prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), "").strip()
# 解析ratio参数
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), "").strip()
return prompt, n, aspect_ratio
def generate_images(self, request: ImageGenerationRequest):
client = genai.Client(api_key=self.paid_key)
if request.size == "1024x1024":
self.aspect_ratio = "1:1"
elif request.size == "1792x1024":
self.aspect_ratio = "16:9"
elif request.size == "1027x1792":
self.aspect_ratio = "9:16"
else:
raise ValueError(
f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
)
# 解析prompt中的参数
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
request.prompt
)
request.prompt = cleaned_prompt
# 如果prompt中指定了n则覆盖请求中的n
if prompt_n > 1:
request.n = prompt_n
# 如果prompt中指定了ratio则覆盖默认的aspect_ratio
if prompt_ratio != self.aspect_ratio:
self.aspect_ratio = prompt_ratio
response = client.models.generate_images(
model=self.image_model,
prompt=request.prompt,
config=types.GenerateImagesConfig(
number_of_images=request.n,
output_mime_type="image/png",
aspect_ratio=self.aspect_ratio,
safety_filter_level="BLOCK_LOW_AND_ABOVE",
person_generation="ALLOW_ADULT",
# language="auto"
),
)
if response.generated_images:
images_data = []
for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes
image_uploader = None
if request.response_format == "b64_json":
base64_image = base64.b64encode(image_data).decode("utf-8")
images_data.append(
{"b64_json": base64_image, "revised_prompt": request.prompt}
)
else:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.SMMS_SECRET_TOKEN,
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY,
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
)
else:
raise ValueError(
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
)
upload_response = image_uploader.upload(image_data, filename)
images_data.append(
{
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt,
}
)
response_data = {
"created": int(time.time()), # Current timestamp
"data": images_data,
}
return response_data
else:
raise Exception("I can't generate these images")
def generate_images_chat(self, request: ImageGenerationRequest) -> str:
response = self.generate_images(request)
image_datas = response["data"]
if image_datas:
markdown_images = []
for index, image_data in enumerate(image_datas):
if "url" in image_data:
markdown_images.append(
f"![Generated Image {index+1}]({image_data['url']})"
)
else:
# 如果是base64格式创建data URL
markdown_images.append(
f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})"
)
return "\n".join(markdown_images)

View File

@@ -1,8 +1,9 @@
import asyncio
from itertools import cycle
from typing import Dict
from app.core.logger import get_key_manager_logger
from app.core.config import settings
from app.config.config import settings
from app.log.logger import get_key_manager_logger
logger = get_key_manager_logger()
@@ -19,7 +20,7 @@ class KeyManager:
async def get_paid_key(self) -> str:
return self.paid_key
async def get_next_key(self) -> str:
"""获取下一个API key"""
async with self.key_cycle_lock:
@@ -61,20 +62,49 @@ class KeyManager:
return await self.get_next_working_key()
def get_fail_count(self, key: str) -> int:
"""获取指定密钥的失败次数"""
return self.key_failure_counts.get(key, 0)
async def get_keys_by_status(self) -> dict:
"""获取分类后的API key列表"""
valid_keys = []
invalid_keys = []
"""获取分类后的API key列表,包括失败次数"""
valid_keys = {}
invalid_keys = {}
async with self.failure_count_lock:
for key in self.api_keys:
masked_key = f"{key}"
if self.key_failure_counts[key] < self.MAX_FAILURES:
valid_keys.append(masked_key)
fail_count = self.key_failure_counts[key]
if fail_count < self.MAX_FAILURES:
valid_keys[key] = fail_count
else:
invalid_keys.append(masked_key)
return {
"valid_keys": valid_keys,
"invalid_keys": invalid_keys
}
invalid_keys[key] = fail_count
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
async def get_first_valid_key(self) -> str:
"""获取第一个有效的API key"""
async with self.failure_count_lock:
for key in self.key_failure_counts:
if self.key_failure_counts[key] < self.MAX_FAILURES:
return key
return self.api_keys[0]
_singleton_instance = None
_singleton_lock = asyncio.Lock()
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
"""
获取 KeyManager 单例实例
如果尚未创建实例将使用提供的 api_keys 初始化 KeyManager
如果已创建实例则忽略 api_keys 参数返回现有单例
"""
global _singleton_instance
async with _singleton_lock:
if _singleton_instance is None:
if api_keys is None:
raise ValueError("API keys are required to initialize the KeyManager")
_singleton_instance = KeyManager(api_keys)
return _singleton_instance

View File

@@ -1,15 +1,20 @@
import requests
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from app.core.logger import get_model_logger
from app.core.config import settings
from typing import Any, Dict, Optional
import requests
from app.config.config import settings
from app.log.logger import get_model_logger
logger = get_model_logger()
class ModelService:
def __init__(self, model_search: list):
self.model_search = model_search
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
def __init__(self, search_models: list, image_models: list):
self.search_models = search_models
self.image_models = image_models
self.base_url = settings.BASE_URL
self.filtered_models = settings.FILTERED_MODELS
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
url = f"{self.base_url}/models?key={api_key}"
@@ -18,6 +23,16 @@ class ModelService:
response = requests.get(url)
if response.status_code == 200:
gemini_models = response.json()
filtered_models_list = []
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]
if model_id not in self.filtered_models:
filtered_models_list.append(model)
else:
logger.info(f"Filtered out model: {model_id}")
gemini_models["models"] = filtered_models_list
return gemini_models
else:
logger.error(f"Error: {response.status_code}")
@@ -36,7 +51,7 @@ class ModelService:
return None
def convert_to_openai_models_format(
self, gemini_models: Dict[str, Any]
self, gemini_models: Dict[str, Any]
) -> Dict[str, Any]:
openai_format = {"object": "list", "data": [], "success": True}
@@ -52,15 +67,32 @@ class ModelService:
"parent": None,
}
openai_format["data"].append(openai_model)
if settings.CREATE_IMAGE_MODEL:
image_model = openai_model.copy()
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model)
if model_id in self.model_search:
if model_id in self.search_models:
search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model)
if model_id in self.image_models:
image_model = openai_model.copy()
image_model["id"] = f"{model_id}-image"
openai_format["data"].append(image_model)
if settings.CREATE_IMAGE_MODEL:
image_model = openai_model.copy()
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model)
return openai_format
def check_model_support(self, model: str) -> bool:
if not model or not isinstance(model, str):
return False
model = model.strip()
if model.endswith("-search"):
model = model[:-7]
return model in self.search_models
if model.endswith("-image"):
model = model[:-6]
return model in self.image_models
return model not in self.filtered_models

View File

@@ -1,53 +0,0 @@
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any
class MessageConverter(ABC):
"""消息转换器基类"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
pass
def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
}
}
return {
"image_url": {
"url": image_url
}
}
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []
if isinstance(msg["content"], str):
parts.append({"text": msg["content"]})
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str):
parts.append({"text": content})
elif isinstance(content, dict):
if content["type"] == "text":
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
converted_messages.append({"role": role, "parts": parts})
return converted_messages

View File

@@ -1,104 +0,0 @@
# app/services/chat_service.py
import json
from typing import Dict, Any, AsyncGenerator, List
from app.core.logger import get_gemini_logger
from app.services.chat.api_client import GeminiApiClient
from app.schemas.gemini_models import GeminiRequest
from app.core.config import settings
from app.services.chat.response_handler import GeminiResponseHandler
from app.services.key_manager import KeyManager
logger = get_gemini_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not _has_image_parts(payload.get("contents", [])):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
]
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload"""
payload = request.model_dump()
return {
"contents": payload.get("contents", []),
"tools": _build_tools(model, payload),
"safetySettings": _get_safety_settings(model),
"generationConfig": payload.get("generationConfig", {}),
"systemInstruction": payload.get("systemInstruction", [])
}
class GeminiChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager):
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
"""生成内容"""
payload = _build_payload(model, request)
response = self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
"""流式生成内容"""
retries = 0
max_retries = 3
payload = _build_payload(model, request)
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(payload, model, api_key):
# print(line)
if line.startswith("data:"):
line = line[6:]
line = json.dumps(self.response_handler.handle_response(json.loads(line), model, stream=True))
yield "data: " + line + "\n\n"
logger.info("Streaming completed successfully")
break
except Exception as e:
retries += 1
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
break

View File

@@ -1,81 +0,0 @@
import time
import uuid
from google import genai
from google.genai import types
import base64
from app.core.config import settings
from app.core.logger import get_image_create_logger
from app.core.uploader import ImageUploaderFactory
from app.schemas.openai_models import ImageGenerationRequest
logger = get_image_create_logger()
class ImageCreateService:
def __init__(self, aspect_ratio="1:1"):
self.image_model = settings.CREATE_IMAGE_MODEL
self.paid_key = settings.PAID_KEY
self.aspect_ratio = aspect_ratio
def generate_images(self, request: ImageGenerationRequest):
client = genai.Client(api_key=self.paid_key)
if request.size == "1024x1024":
self.aspect_ratio = "1:1"
elif request.size == "1792x1024":
self.aspect_ratio = "16:9"
elif request.size == "1027x1792":
self.aspect_ratio = "9:16"
else:
raise ValueError(
f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
)
response = client.models.generate_images(
model=self.image_model,
prompt=request.prompt,
config=types.GenerateImagesConfig(
number_of_images=request.n,
output_mime_type="image/png",
aspect_ratio=self.aspect_ratio,
safety_filter_level="BLOCK_LOW_AND_ABOVE",
person_generation="ALLOW_ADULT",
# language="auto"
),
)
if response.generated_images:
images_data = []
for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
upload_response = image_uploader.upload(image_data,filename)
# base64_image = base64.b64encode(image_data).decode('utf-8')
images_data.append({
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt
})
response_data = {
"created": int(time.time()), # Current timestamp
"data": images_data
}
return response_data
else:
raise Exception("I can't generate these images")
def generate_images_chat(self, request: ImageGenerationRequest) -> str:
response = self.generate_images(request)
image_datas = response["data"]
if image_datas:
markdown_images = []
for index, image_data in enumerate(image_datas):
markdown_images.append(f"![Generated Image {index+1}]({image_data['url']})")
return "\n".join(markdown_images)

View File

@@ -1,192 +0,0 @@
# app/services/chat_service.py
import json
from typing import Dict, Any, AsyncGenerator, List, Union
from app.core.logger import get_openai_logger
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.api_client import GeminiApiClient
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
from app.core.config import settings
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager
logger = get_openai_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _build_tools(
request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
model = request.model
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not _has_image_parts(messages)
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
# if (
# "2.0" in model
# and "gemini-2.0-flash-thinking-exp" not in model
# and "gemini-2.0-pro-exp" not in model
# ):
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
def _build_payload(
request: ChatRequest, messages: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""构建请求payload"""
return {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
"maxOutputTokens": request.max_tokens,
"stopSequences": request.stop,
"topP": request.top_p,
"topK": request.top_k,
},
"tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model),
}
class OpenAIChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager = None):
self.response_handler = OpenAIResponseHandler(config=None)
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.image_create_service = ImageCreateService()
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
messages = self.message_converter.convert(request.messages)
# 构建请求payload
payload = _build_payload(request, messages)
if request.stream:
return self._handle_stream_completion(request.model, payload, api_key)
return self._handle_normal_completion(request.model, payload, api_key)
def _handle_normal_completion(
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(
response, model, stream=False, finish_reason="stop"
)
async def _handle_stream_completion(
self, model: str, payload: Dict[str, Any], api_key: str
) -> AsyncGenerator[str, None]:
"""处理流式聊天完成,添加重试逻辑"""
retries = 0
max_retries = 3
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
chunk = json.loads(line[6:])
openai_chunk = self.response_handler.handle_response(
chunk, model, stream=True, finish_reason=None
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
break # 成功后退出循环
except Exception as e:
retries += 1
logger.warning(
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
)
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(
f"Max retries ({max_retries}) reached for streaming. Raising error"
)
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"
break
async def create_image_chat_completion(
self,
request: ChatRequest,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
image_generate_request = ImageGenerationRequest()
image_generate_request.prompt = request.messages[-1]["content"]
image_res = self.image_create_service.generate_images_chat(image_generate_request)
if request.stream:
return self._handle_stream_image_completion(request.model,image_res)
else:
return self._handle_normal_image_completion(request.model,image_res)
async def _handle_stream_image_completion(
self, model: str, image_data: str
) -> AsyncGenerator[str, None]:
if image_data:
openai_chunk = self.response_handler.handle_image_chat_response(
image_data, model, stream=True, finish_reason=None
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Image chat streaming completed successfully")
def _handle_normal_image_completion(
self, model: str, image_data: str
) -> Dict[str, Any]:
return self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)

249
app/static/css/auth.css Normal file
View File

@@ -0,0 +1,249 @@
body {
font-family: 'Roboto', sans-serif;
line-height: 1.6;
margin: 0;
padding: 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
max-width: 400px;
width: 90%;
background: rgba(255, 255, 255, 0.95);
padding: 40px;
border-radius: 20px;
box-shadow: 0 15px 35px rgba(0,0,0,0.2);
backdrop-filter: blur(10px);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
.container:hover {
transform: translateY(-5px);
box-shadow: 0 20px 40px rgba(0,0,0,0.25);
}
.logo {
text-align: center;
margin-bottom: 30px;
animation: fadeIn 1s ease;
}
.logo i {
font-size: 48px;
color: #764ba2;
margin-bottom: 15px;
}
h2 {
color: #2c3e50;
text-align: center;
margin-bottom: 30px;
font-weight: 700;
font-size: 24px;
animation: slideDown 0.5s ease;
}
form {
display: flex;
flex-direction: column;
gap: 20px;
}
.input-group {
position: relative;
animation: slideUp 0.5s ease;
}
.input-group i {
position: absolute;
left: 12px;
top: 50%;
transform: translateY(-50%);
color: #764ba2;
font-size: 18px;
}
input {
width: 100%;
padding: 12px 12px 12px 40px;
border: 2px solid #e0e0e0;
border-radius: 10px;
font-size: 16px;
box-sizing: border-box;
transition: all 0.3s ease;
background: rgba(255, 255, 255, 0.9);
}
input:focus {
border-color: #764ba2;
box-shadow: 0 0 10px rgba(118, 75, 162, 0.2);
outline: none;
}
button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 14px;
border-radius: 10px;
cursor: pointer;
font-size: 16px;
font-weight: bold;
transition: all 0.3s ease;
position: relative;
overflow: hidden;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(118, 75, 162, 0.3);
}
button:active {
transform: translateY(0);
}
button::after {
content: '';
position: absolute;
top: 50%;
left: 50%;
width: 0;
height: 0;
background: rgba(255, 255, 255, 0.2);
border-radius: 50%;
transform: translate(-50%, -50%);
transition: width 0.6s, height 0.6s;
}
button:active::after {
width: 200px;
height: 200px;
opacity: 0;
}
.error-message {
color: #e74c3c;
margin-top: 15px;
text-align: center;
font-weight: bold;
padding: 10px;
border-radius: 5px;
background: rgba(231, 76, 60, 0.1);
animation: shake 0.5s ease;
}
.copyright {
position: fixed;
bottom: 0;
left: 0;
width: 100%;
background: rgba(255, 255, 255, 0.9);
padding: 10px 0;
text-align: center;
font-size: 14px;
color: #2c3e50;
backdrop-filter: blur(5px);
border-top: 1px solid rgba(0,0,0,0.1);
}
.copyright a {
color: #764ba2;
text-decoration: none;
transition: color 0.3s ease;
}
.copyright a:hover {
color: #667eea;
}
.copyright img {
width: 20px;
height: 20px;
border-radius: 50%;
vertical-align: middle;
margin-right: 5px;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
@keyframes slideDown {
from { transform: translateY(-20px); opacity: 0; }
to { transform: translateY(0); opacity: 1; }
}
@keyframes slideUp {
from { transform: translateY(20px); opacity: 0; }
to { transform: translateY(0); opacity: 1; }
}
@keyframes shake {
0%, 100% { transform: translateX(0); }
25% { transform: translateX(-5px); }
75% { transform: translateX(5px); }
}
@media (max-width: 768px) {
.container {
width: 85%;
padding: 30px;
}
.logo i {
font-size: 40px;
}
h2 {
font-size: 22px;
}
input {
padding: 10px 10px 10px 35px;
font-size: 15px;
}
.input-group i {
font-size: 16px;
}
button {
padding: 12px;
font-size: 15px;
}
}
@media (max-width: 480px) {
.container {
width: 90%;
padding: 25px;
}
.logo i {
font-size: 36px;
}
h2 {
font-size: 20px;
margin-bottom: 25px;
}
form {
gap: 15px;
}
input {
padding: 10px 10px 10px 32px;
font-size: 14px;
}
.input-group i {
font-size: 15px;
left: 10px;
}
button {
padding: 10px;
font-size: 14px;
}
.error-message {
font-size: 14px;
padding: 8px;
margin-top: 12px;
}
}

View File

@@ -0,0 +1,461 @@
body {
font-family: 'Roboto', sans-serif;
line-height: 1.6;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
}
.container {
max-width: 900px;
width: 95%;
background: rgba(255, 255, 255, 0.95);
padding: 40px;
border-radius: 20px;
box-shadow: 0 15px 35px rgba(0,0,0,0.2);
backdrop-filter: blur(10px);
position: relative;
margin: 20px auto;
overflow-y: auto;
max-height: calc(100vh - 40px);
scrollbar-width: none;
-ms-overflow-style: none;
}
.container::-webkit-scrollbar {
display: none;
}
h1 {
color: #2c3e50;
text-align: center;
margin-bottom: 30px;
font-weight: 700;
font-size: 32px;
position: relative;
padding-bottom: 15px;
}
h1::after {
content: '';
position: absolute;
bottom: 0;
left: 50%;
transform: translateX(-50%);
width: 100px;
height: 4px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 2px;
}
.key-list {
margin-bottom: 30px;
background: rgba(248, 249, 250, 0.9);
padding: 25px;
border-radius: 15px;
transition: all 0.3s ease;
border: 1px solid rgba(0,0,0,0.1);
animation: fadeIn 0.5s ease forwards;
}
.key-list:hover {
transform: translateY(-5px);
box-shadow: 0 10px 20px rgba(0,0,0,0.1);
}
.key-list:nth-child(2) {
animation-delay: 0.2s;
}
.key-list h2 {
color: #2c3e50;
margin-bottom: 20px;
display: flex;
justify-content: space-between;
align-items: center;
font-size: 1.5em;
padding-bottom: 10px;
border-bottom: 2px solid rgba(0,0,0,0.1);
cursor: pointer;
}
.key-list h2 .toggle-icon {
margin-right: 10px;
transition: transform 0.3s ease;
}
.key-list h2 .toggle-icon.collapsed {
transform: rotate(-90deg);
}
.key-list .key-content {
transition: all 0.3s ease-out;
overflow: hidden;
height: auto;
opacity: 1;
}
.key-list .key-content.collapsed {
height: 0;
opacity: 0;
padding-top: 0;
padding-bottom: 0;
}
ul {
list-style-type: none;
padding: 0;
margin: 0;
}
li {
background: white;
border: 1px solid rgba(0,0,0,0.1);
margin-bottom: 12px;
padding: 15px;
border-radius: 10px;
transition: all 0.3s ease;
display: flex;
justify-content: space-between;
align-items: center;
box-shadow: 0 2px 5px rgba(0,0,0,0.05);
}
li:hover {
transform: translateX(5px);
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
.key-info {
display: flex;
align-items: center;
gap: 15px;
flex: 1;
}
.key-text {
font-family: 'Roboto Mono', monospace;
color: #2c3e50;
}
.fail-count {
background: rgba(231, 76, 60, 0.1);
color: #e74c3c;
padding: 4px 10px;
border-radius: 15px;
font-size: 0.85em;
display: flex;
align-items: center;
gap: 5px;
}
.fail-count i {
font-size: 12px;
}
.key-actions {
display: flex;
gap: 10px;
align-items: center;
}
.verify-btn, .copy-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 8px 15px;
border-radius: 8px;
cursor: pointer;
font-size: 14px;
font-weight: bold;
transition: all 0.3s ease;
display: flex;
align-items: center;
gap: 5px;
}
.verify-btn {
background: linear-gradient(135deg, #2ecc71, #27ae60);
}
.verify-btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(46, 204, 113, 0.3);
}
.verify-btn:disabled {
opacity: 0.7;
cursor: not-allowed;
transform: none;
box-shadow: none;
}
.verify-btn i {
font-size: 14px;
}
.copy-btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(118, 75, 162, 0.3);
}
.copy-btn:active {
transform: translateY(0);
}
.copy-btn i {
font-size: 14px;
}
.total {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px 25px;
border-radius: 10px;
font-weight: bold;
text-align: center;
font-size: 1.2em;
margin-top: 30px;
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
#copyStatus {
position: fixed;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
padding: 15px 30px;
border-radius: 25px;
font-weight: bold;
opacity: 0;
transition: all 0.3s ease;
backdrop-filter: blur(5px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
z-index: 1000;
text-align: center;
min-width: 200px;
color: white;
}
#copyStatus.success {
background: rgba(39, 174, 96, 0.95);
}
#copyStatus.error {
background: rgba(231, 76, 60, 0.95);
}
.status-badge {
padding: 4px 12px;
border-radius: 15px;
font-size: 0.9em;
font-weight: bold;
margin-right: 10px;
}
.status-valid {
background: rgba(39, 174, 96, 0.1);
color: #27ae60;
}
.status-invalid {
background: rgba(231, 76, 60, 0.1);
color: #e74c3c;
}
.scroll-buttons {
position: fixed;
right: 20px;
bottom: 20px;
display: none;
flex-direction: column;
gap: 10px;
z-index: 1000;
}
.scroll-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
width: 40px;
height: 40px;
border: none;
border-radius: 50%;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
font-size: 20px;
transition: all 0.3s ease;
backdrop-filter: blur(5px);
box-shadow: 0 2px 10px rgba(0,0,0,0.2);
}
.scroll-btn:hover {
background: linear-gradient(135deg, #764ba2 0%, #667eea 100%);
transform: scale(1.1);
}
.scroll-btn:active {
transform: scale(0.95);
}
.refresh-btn {
position: fixed;
top: 20px;
right: 20px;
z-index: 1000;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: #fff;
border: none;
padding: 10px 20px;
border-radius: 25px;
cursor: pointer;
font-size: 14px;
font-weight: bold;
transition: all 0.3s ease;
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
}
.refresh-btn:hover {
transform: scale(1.05);
box-shadow: 0 8px 20px rgba(118, 75, 162, 0.3);
background: linear-gradient(135deg, #764ba2 0%, #667eea 100%);
}
.refresh-btn:active {
transform: scale(0.95);
}
.refresh-btn i {
transition: transform 0.5s ease;
}
.refresh-btn.loading i {
animation: spin 1s linear infinite;
}
.copyright {
position: fixed;
bottom: 0;
left: 0;
width: 100%;
background: rgba(255, 255, 255, 0.9);
padding: 10px 0;
text-align: center;
font-size: 14px;
color: #2c3e50;
backdrop-filter: blur(5px);
border-top: 1px solid rgba(0,0,0,0.1);
}
.copyright a {
color: #764ba2;
text-decoration: none;
transition: color 0.3s ease;
}
.copyright a:hover {
color: #667eea;
}
.copyright img {
width: 20px;
height: 20px;
border-radius: 50%;
vertical-align: middle;
margin-right: 5px;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes spin {
from { transform: rotate(0deg); }
to { transform: rotate(360deg); }
}
@media (max-width: 768px) {
.container {
width: 100%;
padding: 20px;
margin: 10px auto;
}
body {
padding: 10px;
}
h1 {
font-size: 24px;
}
.key-list h2 {
font-size: 1.2em;
flex-direction: column;
gap: 10px;
align-items: flex-start;
}
.key-info {
flex-direction: column;
align-items: flex-start;
gap: 8px;
}
li {
flex-direction: column;
gap: 10px;
}
.key-actions {
width: 100%;
flex-direction: column;
}
.verify-btn, .copy-btn {
width: 100%;
justify-content: center;
}
.key-text {
word-break: break-all;
}
.scroll-buttons {
right: 10px;
bottom: 10px;
}
.scroll-btn {
width: 35px;
height: 35px;
font-size: 16px;
}
.refresh-btn {
top: 10px;
right: 10px;
padding: 8px 16px;
font-size: 12px;
}
}
@media (max-width: 480px) {
.container {
padding: 15px;
}
h1 {
font-size: 20px;
}
.key-list {
padding: 15px;
}
.status-badge {
padding: 3px 8px;
font-size: 0.8em;
}
.fail-count {
font-size: 0.8em;
}
.total {
font-size: 1em;
padding: 12px 20px;
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

18
app/static/js/auth.js Normal file
View File

@@ -0,0 +1,18 @@
if ('serviceWorker' in navigator) {
window.addEventListener('load', () => {
navigator.serviceWorker.register('/static/service-worker.js')
.then(registration => {
console.log('ServiceWorker注册成功:', registration.scope);
})
.catch(error => {
console.log('ServiceWorker注册失败:', error);
});
});
}
document.addEventListener('DOMContentLoaded', () => {
const copyrightYear = document.querySelector('.copyright script');
if (copyrightYear) {
copyrightYear.textContent = new Date().getFullYear();
}
});

View File

@@ -0,0 +1,175 @@
function copyToClipboard(text) {
if (navigator.clipboard && navigator.clipboard.writeText) {
return navigator.clipboard.writeText(text);
} else {
return new Promise((resolve, reject) => {
const textArea = document.createElement("textarea");
textArea.value = text;
textArea.style.position = "fixed";
document.body.appendChild(textArea);
textArea.focus();
textArea.select();
try {
const successful = document.execCommand('copy');
document.body.removeChild(textArea);
if (successful) {
resolve();
} else {
reject(new Error('复制失败'));
}
} catch (err) {
document.body.removeChild(textArea);
reject(err);
}
});
}
}
function copyKeys(type) {
const keys = Array.from(document.querySelectorAll(`#${type}Keys .key-text`)).map(span => span.textContent.trim());
const jsonKeys = JSON.stringify(keys);
copyToClipboard(jsonKeys)
.then(() => {
showCopyStatus(`已成功复制${type === 'valid' ? '有效' : '无效'}密钥到剪贴板`);
})
.catch((err) => {
console.error('无法复制文本: ', err);
showCopyStatus('复制失败,请重试');
});
}
function copyKey(key) {
copyToClipboard(key)
.then(() => {
showCopyStatus(`已成功复制密钥到剪贴板`);
})
.catch((err) => {
console.error('无法复制文本: ', err);
showCopyStatus('复制失败,请重试');
});
}
function showCopyStatus(message, type = 'success') {
const statusElement = document.getElementById('copyStatus');
statusElement.textContent = message;
statusElement.className = type; // 设置样式类
statusElement.style.opacity = 1;
setTimeout(() => {
statusElement.style.opacity = 0;
setTimeout(() => {
statusElement.className = ''; // 清除样式类
}, 300);
}, 2000);
}
async function verifyKey(key, button) {
try {
// 禁用按钮并显示加载状态
button.disabled = true;
const originalHtml = button.innerHTML;
button.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 验证中';
const response = await fetch(`/gemini/v1beta/verify-key/${key}`, {
method: 'POST'
});
const data = await response.json();
// 根据验证结果更新UI
if (data.status === 'valid') {
showCopyStatus('密钥验证成功', 'success');
button.style.backgroundColor = '#27ae60';
} else {
showCopyStatus('密钥验证失败', 'error');
button.style.backgroundColor = '#e74c3c';
}
// 3秒后恢复按钮原始状态
setTimeout(() => {
button.innerHTML = originalHtml;
button.disabled = false;
button.style.backgroundColor = '';
}, 3000);
} catch (error) {
console.error('验证失败:', error);
showCopyStatus('验证请求失败', 'error');
button.disabled = false;
button.innerHTML = '<i class="fas fa-check-circle"></i> 验证';
}
}
function scrollToTop() {
const container = document.querySelector('.container');
container.scrollTo({
top: 0,
behavior: 'smooth'
});
}
function scrollToBottom() {
const container = document.querySelector('.container');
container.scrollTo({
top: container.scrollHeight,
behavior: 'smooth'
});
}
function updateScrollButtons() {
const container = document.querySelector('.container');
const scrollButtons = document.querySelector('.scroll-buttons');
if (container.scrollHeight > container.clientHeight) {
scrollButtons.style.display = 'flex';
} else {
scrollButtons.style.display = 'none';
}
}
function refreshPage(button) {
button.classList.add('loading');
button.disabled = true;
setTimeout(() => {
window.location.reload();
}, 300);
}
function toggleSection(header, sectionId) {
const toggleIcon = header.querySelector('.toggle-icon');
const content = header.nextElementSibling;
toggleIcon.classList.toggle('collapsed');
content.classList.toggle('collapsed');
}
// 初始化
document.addEventListener('DOMContentLoaded', () => {
// 检查滚动按钮
updateScrollButtons();
// 监听展开/折叠事件
document.querySelectorAll('.key-list h2').forEach(header => {
header.addEventListener('click', () => {
setTimeout(updateScrollButtons, 300);
});
});
// 更新版权年份
const copyrightYear = document.querySelector('.copyright script');
if (copyrightYear) {
copyrightYear.textContent = new Date().getFullYear();
}
});
// Service Worker registration
if ('serviceWorker' in navigator) {
window.addEventListener('load', () => {
navigator.serviceWorker.register('/static/service-worker.js')
.then(registration => {
console.log('ServiceWorker注册成功:', registration.scope);
})
.catch(error => {
console.log('ServiceWorker注册失败:', error);
});
});
}

17
app/static/manifest.json Normal file
View File

@@ -0,0 +1,17 @@
{
"name": "Gemini Balance",
"short_name": "GBalance",
"description": "Gemini API密钥管理工具",
"start_url": "/",
"display": "standalone",
"background_color": "#667eea",
"theme_color": "#764ba2",
"icons": [
{
"src": "/static/icons/icon-192x192.png",
"sizes": "192x192",
"type": "image/png",
"purpose": "any maskable"
}
]
}

View File

@@ -0,0 +1,43 @@
const CACHE_NAME = 'gbalance-cache-v1';
const urlsToCache = [
'/',
'/static/manifest.json',
'/static/icons/icon-192x192.png'
];
self.addEventListener('install', event => {
event.waitUntil(
caches.open(CACHE_NAME)
.then(cache => {
console.log('Opened cache');
return cache.addAll(urlsToCache);
})
);
});
self.addEventListener('fetch', event => {
event.respondWith(
caches.match(event.request)
.then(response => {
if (response) {
return response;
}
return fetch(event.request);
})
);
});
self.addEventListener('activate', event => {
const cacheWhitelist = [CACHE_NAME];
event.waitUntil(
caches.keys().then(cacheNames => {
return Promise.all(
cacheNames.map(cacheName => {
if (cacheWhitelist.indexOf(cacheName) === -1) {
return caches.delete(cacheName);
}
})
);
})
);
});

42
app/templates/auth.html Normal file
View File

@@ -0,0 +1,42 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>验证页面</title>
<link rel="manifest" href="/static/manifest.json">
<meta name="theme-color" content="#764ba2">
<meta name="apple-mobile-web-app-capable" content="yes">
<meta name="apple-mobile-web-app-status-bar-style" content="black">
<meta name="apple-mobile-web-app-title" content="GBalance">
<link rel="icon" href="/static/icons/icon-192x192.png">
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
<link rel="stylesheet" href="/static/css/auth.css">
</head>
<body>
<div class="container">
<div class="logo">
<i class="fas fa-shield-alt"></i>
</div>
<h2>安全验证</h2>
<form id="auth-form" action="/auth" method="post">
<div class="input-group">
<i class="fas fa-key"></i>
<input type="password" id="auth-token" name="auth_token" required placeholder="请输入验证令牌">
</div>
<button type="submit">
验证访问
</button>
</form>
{% if error %}
<p class="error-message">{{ error }}</p>
{% endif %}
</div>
<div class="copyright">
© <script>document.write(new Date().getFullYear())</script> by <a href="https://linux.do/u/snaily" target="_blank"><img src="https://linux.do/user_avatar/linux.do/snaily/288/306510_2.gif" alt="snaily">snaily</a> |
<a href="https://github.com/snailyp/gemini-balance" target="_blank"><i class="fab fa-github"></i> GitHub</a>
</div>
<script src="/static/js/auth.js"></script>
</body>
</html>

View File

@@ -0,0 +1,128 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>API密钥状态</title>
<link rel="manifest" href="/static/manifest.json">
<meta name="theme-color" content="#764ba2">
<meta name="apple-mobile-web-app-capable" content="yes">
<meta name="apple-mobile-web-app-status-bar-style" content="black">
<meta name="apple-mobile-web-app-title" content="GBalance">
<link rel="icon" href="/static/icons/icon-192x192.png">
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
<link rel="stylesheet" href="/static/css/keys_status.css">
</head>
<body>
<div class="container">
<button class="refresh-btn" onclick="refreshPage(this)">
<i class="fas fa-sync-alt"></i>
</button>
<h1>API密钥状态</h1>
<div class="key-list">
<h2 onclick="toggleSection(this, 'validKeys')">
<span>
<i class="fas fa-chevron-down toggle-icon"></i>
<i class="fas fa-check-circle" style="color: #27ae60;"></i>
有效密钥
</span>
<button class="copy-btn" onclick="event.stopPropagation(); copyKeys('valid')">
<i class="fas fa-copy"></i>
批量复制
</button>
</h2>
<div class="key-content">
<ul id="validKeys">
{% for key, fail_count in valid_keys.items() %}
<li>
<div class="key-info">
<span class="status-badge status-valid">
<i class="fas fa-check"></i> 有效
</span>
<span class="key-text">{{ key }}</span>
<span class="fail-count">
<i class="fas fa-exclamation-triangle"></i>
失败: {{ fail_count }}
</span>
</div>
<div class="key-actions">
<button class="verify-btn" onclick="verifyKey('{{ key }}', this)">
<i class="fas fa-check-circle"></i>
验证
</button>
<button class="copy-btn" onclick="copyKey('{{ key }}')">
<i class="fas fa-copy"></i>
复制
</button>
</div>
</li>
{% endfor %}
</ul>
</div>
</div>
<div class="key-list">
<h2 onclick="toggleSection(this, 'invalidKeys')">
<span>
<i class="fas fa-chevron-down toggle-icon"></i>
<i class="fas fa-times-circle" style="color: #e74c3c;"></i>
无效密钥
</span>
<button class="copy-btn" onclick="event.stopPropagation(); copyKeys('invalid')">
<i class="fas fa-copy"></i>
批量复制
</button>
</h2>
<div class="key-content">
<ul id="invalidKeys">
{% for key, fail_count in invalid_keys.items() %}
<li>
<div class="key-info">
<span class="status-badge status-invalid">
<i class="fas fa-times"></i> 无效
</span>
<span class="key-text">{{ key }}</span>
<span class="fail-count">
<i class="fas fa-exclamation-triangle"></i>
失败: {{ fail_count }}
</span>
</div>
<div class="key-actions">
<button class="verify-btn" onclick="verifyKey('{{ key }}', this)">
<i class="fas fa-check-circle"></i>
验证
</button>
<button class="copy-btn" onclick="copyKey('{{ key }}')">
<i class="fas fa-copy"></i>
复制
</button>
</div>
</li>
{% endfor %}
</ul>
</div>
</div>
<div class="total">
<i class="fas fa-key"></i> 总密钥数:{{ total }}
</div>
</div>
<div class="scroll-buttons">
<button class="scroll-btn" onclick="scrollToTop()" title="回到顶部">
<i class="fas fa-chevron-up"></i>
</button>
<button class="scroll-btn" onclick="scrollToBottom()" title="滚动到底部">
<i class="fas fa-chevron-down"></i>
</button>
</div>
<div id="copyStatus"></div>
<div class="copyright">
© <script>document.write(new Date().getFullYear())</script> by <a href="https://linux.do/u/snaily" target="_blank"><img src="https://linux.do/user_avatar/linux.do/snaily/288/306510_2.gif" alt="snaily">snaily</a> |
<a href="https://github.com/snailyp/gemini-balance" target="_blank"><i class="fab fa-github"></i> GitHub</a>
</div>
<script src="/static/js/keys_status.js"></script>
</body>
</html>

3
app/utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
工具包初始化模块
"""

146
app/utils/helpers.py Normal file
View File

@@ -0,0 +1,146 @@
"""
通用工具函数模块
"""
import json
import re
import base64
import requests
from typing import Dict, Any, List, Optional, Tuple
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
"""
从 base64 字符串中提取 MIME 类型和数据
Args:
base64_string: 可能包含 MIME 类型信息的 base64 字符串
Returns:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
Raises:
Exception: 如果获取图片失败
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
"""
格式化JSON响应
Args:
data: 要格式化的数据
indent: 缩进空格数
Returns:
str: 格式化后的JSON字符串
"""
return json.dumps(data, indent=indent, ensure_ascii=False)
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
"""
从prompt中解析参数
支持的格式:
- {n:数量} 例如: {n:2} 生成2张图片
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
Args:
prompt: 提示文本
default_ratio: 默认比例
Returns:
tuple: (清理后的提示文本, 图片数量, 比例)
"""
# 默认值
n = 1
aspect_ratio = default_ratio
# 解析n参数
n_match = re.search(r'{n:(\d+)}', prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), '').strip()
# 解析ratio参数
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()
return prompt, n, aspect_ratio
def extract_image_urls_from_markdown(text: str) -> List[str]:
"""
从Markdown文本中提取图片URL
Args:
text: Markdown文本
Returns:
List[str]: 图片URL列表
"""
pattern = IMAGE_URL_PATTERN
matches = re.findall(pattern, text)
return [match[1] for match in matches]
def is_valid_api_key(key: str) -> bool:
"""
检查API密钥格式是否有效
Args:
key: API密钥
Returns:
bool: 如果密钥格式有效则返回True
"""
# 检查Gemini API密钥格式
if key.startswith('AIza'):
return len(key) >= 30
# 检查OpenAI API密钥格式
if key.startswith('sk-'):
return len(key) >= 30
return False

393
app/utils/uploader.py Normal file
View File

@@ -0,0 +1,393 @@
import requests
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
from enum import Enum
from typing import Optional, Any
class UploadErrorType(Enum):
"""上传错误类型枚举"""
NETWORK_ERROR = "network_error" # 网络请求错误
AUTH_ERROR = "auth_error" # 认证错误
INVALID_FILE = "invalid_file" # 无效文件
SERVER_ERROR = "server_error" # 服务器错误
PARSE_ERROR = "parse_error" # 响应解析错误
UNKNOWN = "unknown" # 未知错误
class UploadError(Exception):
"""图片上传错误异常类"""
def __init__(
self,
message: str,
error_type: UploadErrorType = UploadErrorType.UNKNOWN,
status_code: Optional[int] = None,
details: Optional[dict] = None,
original_error: Optional[Exception] = None
):
"""
初始化上传错误异常
Args:
message: 错误消息
error_type: 错误类型
status_code: HTTP状态码
details: 详细错误信息
original_error: 原始异常
"""
self.message = message
self.error_type = error_type
self.status_code = status_code
self.details = details or {}
self.original_error = original_error
# 构建完整错误信息
full_message = f"[{error_type.value}] {message}"
if status_code:
full_message = f"{full_message} (Status: {status_code})"
if details:
full_message = f"{full_message} - Details: {details}"
super().__init__(full_message)
@classmethod
def from_response(cls, response: Any, message: Optional[str] = None) -> "UploadError":
"""
从HTTP响应创建错误实例
Args:
response: HTTP响应对象
message: 自定义错误消息
"""
try:
error_data = response.json()
details = error_data.get("data", {})
return cls(
message=message or error_data.get("message", "Unknown error"),
error_type=UploadErrorType.SERVER_ERROR,
status_code=response.status_code,
details=details
)
except Exception:
return cls(
message=message or "Failed to parse error response",
error_type=UploadErrorType.PARSE_ERROR,
status_code=response.status_code
)
class SmMsUploader(ImageUploader):
API_URL = "https://sm.ms/api/v2/upload"
def __init__(self, api_key: str):
self.api_key = api_key
def upload(self, file: bytes, filename: str) -> UploadResponse:
try:
# 准备请求头
headers = {
"Authorization": f"Basic {self.api_key}"
}
# 准备文件数据
files = {
"smfile": (filename, file, "image/png")
}
# 发送请求
response = requests.post(
self.API_URL,
headers=headers,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证上传是否成功
if not result.get("success"):
raise UploadError(result.get("message", "Upload failed"))
# 转换为统一格式
data = result["data"]
image_metadata = ImageMetadata(
width=data["width"],
height=data["height"],
filename=data["filename"],
size=data["size"],
url=data["url"],
delete_url=data["delete"]
)
return UploadResponse(
success=True,
code="success",
message="Upload success",
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(f"Upload request failed: {str(e)}")
except (KeyError, ValueError) as e:
# 处理响应解析错误
raise UploadError(f"Invalid response format: {str(e)}")
except Exception as e:
# 处理其他未预期的错误
raise UploadError(f"Upload failed: {str(e)}")
class QiniuUploader(ImageUploader):
def __init__(self, access_key: str, secret_key: str):
self.access_key = access_key
self.secret_key = secret_key
def upload(self, file: bytes, filename: str) -> UploadResponse:
# 实现七牛云的具体上传逻辑
pass
class PicGoUploader(ImageUploader):
"""Chevereto API 图片上传器"""
def __init__(self, api_key: str, api_url: str = "https://www.picgo.net/api/1/upload"):
"""
初始化 Chevereto 上传器
Args:
api_key: Chevereto API 密钥
api_url: Chevereto API 上传地址
"""
self.api_key = api_key
self.api_url = api_url
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到 Chevereto 服务
Args:
file: 图片文件二进制数据
filename: 文件名
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
try:
# 准备请求头
headers = {
"X-API-Key": self.api_key
}
# 准备文件数据
files = {
"source": (filename, file)
}
# 发送请求
response = requests.post(
self.api_url,
headers=headers,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证上传是否成功
if result.get("status_code") != 200:
error_message = "Upload failed"
if "error" in result:
error_message = result["error"].get("message", error_message)
raise UploadError(
message=error_message,
error_type=UploadErrorType.SERVER_ERROR,
status_code=result.get("status_code"),
details=result.get("error")
)
# 从响应中提取图片信息
image_data = result.get("image", {})
# 构建图片元数据
image_metadata = ImageMetadata(
width=image_data.get("width", 0),
height=image_data.get("height", 0),
filename=image_data.get("filename", filename),
size=image_data.get("size", 0),
url=image_data.get("url", ""),
delete_url=image_data.get("delete_url", None)
)
return UploadResponse(
success=True,
code="success",
message=result.get("success", {}).get("message", "Upload success"),
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(
message=f"Upload request failed: {str(e)}",
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except (KeyError, ValueError, TypeError) as e:
# 处理响应解析错误
raise UploadError(
message=f"Invalid response format: {str(e)}",
error_type=UploadErrorType.PARSE_ERROR,
original_error=e
)
except UploadError:
# 重新抛出已经是 UploadError 类型的异常
raise
except Exception as e:
# 处理其他未预期的错误
raise UploadError(
message=f"Upload failed: {str(e)}",
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class CloudFlareImgBedUploader(ImageUploader):
"""CloudFlare图床上传器"""
def __init__(self, auth_code: str, api_url: str):
"""
初始化CloudFlare图床上传器
Args:
auth_code: 认证码
api_url: 上传API地址
"""
self.auth_code = auth_code
self.api_url = api_url
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到CloudFlare图床
Args:
file: 图片文件二进制数据
filename: 文件名
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
try:
# 准备请求URL添加认证码参数如果存在
if self.auth_code:
request_url = f"{self.api_url}?authCode={self.auth_code}&uploadNameType=origin"
else:
request_url = f"{self.api_url}?uploadNameType=origin"
# 准备文件数据
files = {
"file": (filename, file)
}
# 发送请求
response = requests.post(
request_url,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证响应格式
if not result or not isinstance(result, list) or len(result) == 0:
raise UploadError(
message="Invalid response format",
error_type=UploadErrorType.PARSE_ERROR
)
# 获取文件URL
file_path = result[0].get("src")
if not file_path:
raise UploadError(
message="Missing file URL in response",
error_type=UploadErrorType.PARSE_ERROR
)
# 构建完整URL如果返回的是相对路径
base_url = self.api_url.split("/upload")[0]
full_url = file_path if file_path.startswith(("http://", "https://")) else f"{base_url}{file_path}"
# 构建图片元数据注意CloudFlare-ImgBed不返回所有元数据所以部分字段为默认值
image_metadata = ImageMetadata(
width=0, # CloudFlare-ImgBed不返回宽度
height=0, # CloudFlare-ImgBed不返回高度
filename=filename,
size=0, # CloudFlare-ImgBed不返回大小
url=full_url,
delete_url=None # CloudFlare-ImgBed不返回删除URL
)
return UploadResponse(
success=True,
code="success",
message="Upload success",
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(
message=f"Upload request failed: {str(e)}",
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except (KeyError, ValueError, TypeError, IndexError) as e:
# 处理响应解析错误
raise UploadError(
message=f"Invalid response format: {str(e)}",
error_type=UploadErrorType.PARSE_ERROR,
original_error=e
)
except UploadError:
# 重新抛出已经是 UploadError 类型的异常
raise
except Exception as e:
# 处理其他未预期的错误
raise UploadError(
message=f"Upload failed: {str(e)}",
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class ImageUploaderFactory:
@staticmethod
def create(provider: str, **credentials) -> ImageUploader:
if provider == "smms":
return SmMsUploader(credentials["api_key"])
elif provider == "qiniu":
return QiniuUploader(
credentials["access_key"],
credentials["secret_key"]
)
elif provider == "picgo":
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
return PicGoUploader(credentials["api_key"], api_url)
elif provider == "cloudflare_imgbed":
return CloudFlareImgBedUploader(
credentials["auth_code"],
credentials["base_url"]
)
raise ValueError(f"Unknown provider: {provider}")

9
docker-compose.yml Normal file
View File

@@ -0,0 +1,9 @@
version: '3'
services:
gemini-balance:
build: .
ports:
- "8000:8000"
env_file:
- .env

View File

@@ -6,4 +6,6 @@ pydantic_settings
requests
starlette
uvicorn
google-genai
google-genai
jinja2
python-multipart