Compare commits

..

24 Commits

Author SHA1 Message Date
snaily
cc36ba4c9e feat(config): 新增流式输出优化器开关配置
在环境变量示例文件(.env.example)和配置类(config.py)中新增 STREAM_OPTIMIZER_ENABLED 配置项,用于控制流式输出优化器的启用状态,默认设为 false

调整 Gemini 和 OpenAI 聊天服务的流式响应处理逻辑:
- 仅在流式优化器启用时(settings.STREAM_OPTIMIZER_ENABLED 为 true)
- 才会对文本内容执行流式输出优化处理
- 保持原有文本提取逻辑不变,仅增加配置条件判断

该变更使流式输出优化器变为可选功能,方便根据实际需求进行开关控制
2025-04-03 04:47:06 +08:00
snaily
baf643e884 feat: 新增请求超时配置及优化模型列表接口api_key获取方式
1. 新增功能:
   - 在`.env.example`中添加`TIME_OUT=300`配置项(包含中文注释)
   - 在`Settings`类中增加`TIME_OUT`字段(读取自`DEFAULT_TIMEOUT`)

2. 优化内容:
   - 生成配置:
     * 为`GenerationConfig`设置默认温度/TOP_P/TOP_K值
     * 移除`maxOutputTokens`默认值,改为可选传递
   - OpenAI请求:
     * 移除`max_tokens`默认值
     * 只有当`max_tokens`有值时才添加到请求payload
   - 日志优化:
     * 注释掉`stream_optimizer.py`中部分调试日志

3. 模型列表接口api_key获取方式
2025-04-03 03:12:59 +08:00
严浩
360bc9e48d feat(ci): 更新Docker发布工作流 2025-04-02 13:49:05 +08:00
snaily
c0a27d0542 Update README.md 2025-03-29 01:03:36 +08:00
snaily
84052a2179 feat(auth): 增强Gemini API的认证机制支持URL参数
- 将generate_content和stream_generate_content端点的认证依赖从verify_goog_api_key更改为verify_key_or_goog_api_key
- 使Gemini API同时支持URL参数中的key和请求头中的x-goog-api-key进行认证
- 提高API的灵活性,便于不同客户端集成
2025-03-28 23:44:40 +08:00
snaily
2e7ecd88b5 feat: 增强Gemini API tools参数处理
- 修改GeminiRequest模型,使tools字段支持单个工具对象或工具对象列表
- 在gemini_chat_service中添加类型转换逻辑,确保tools始终以列表形式处理
- 提高API的灵活性和兼容性
2025-03-28 20:50:01 +08:00
snaily
0b1f3dfc04 feat(auth): 支持x-goog-api-key请求头认证
- 添加verify_key_or_goog_api_key方法,支持同时验证URL参数中的key和请求头中的x-goog-api-key
- 更新models接口使用新的认证方法,提高与Google API客户端的兼容性
2025-03-28 19:27:42 +08:00
snaily
c691c7c1cf fix:当没有可用工具时返回空列表而非包含空字典的列表
在_build_tools函数中,当没有工具配置可用时(即tool为空字典),现在会返回空列表[]而不是[{}]。这个防御性编程修复可以避免向Gemini API发送无效的工具配置,防止可能的API调用错误。
2025-03-25 15:18:27 +08:00
snaily
97db7eebf1 chore:修改图片处理逻辑,统一使用base64编码
将_convert_image函数中对非data:image格式URL的处理方式从直接返回URL改为转换为base64编码的内联数据。这样无论图片是以data URI形式还是URL形式提供,都会统一转换为base64编码,确保与API交互时图片数据格式的一致性。
2025-03-25 13:23:17 +08:00
snaily
60dca70fcd fix: 改进图片显示和移除调试输出
优化图片链接格式,在图片前后添加空行以改善显示效果
注释掉OpenAI聊天服务中的调试打印语句
2025-03-22 03:38:45 +08:00
snaily
89b9f7919a feat: 添加对OpenAI工具调用功能的支持
改进消息转换器以处理OpenAI的tool_calls格式
添加JSON解析以正确转换函数调用参数
优化消息处理逻辑,增加更多空值检查
在流式响应中添加工具调用检测和处理
根据工具调用状态设置适当的finish_reason
2025-03-22 02:48:25 +08:00
Toddy
a8dc98ab6a fix tool use with function calling is unsupported error 2025-03-21 05:04:53 +00:00
snaily
b3a057b6ba refactor: 代码结构优化与常量化
将日志系统从 app/logger/ 移至 app/log/ 目录
将路由配置从 routers.py 重命名为 routes.py
将硬编码配置值移至 constants.py 中的默认常量
统一代码格式和导入排序
优化函数参数对齐方式
2025-03-20 21:59:18 +08:00
snaily
b14bb93d8f refactor: 项目结构优化与FastAPI生命周期更新
重构项目目录结构,提高代码组织性和可维护性

将schemas目录重命名为domain,更好地表达领域模型概念
将services目录细分为service/chat、service/image等子目录
将api目录重命名为router,更符合FastAPI惯例
创建utils目录存放通用工具函数
更新FastAPI应用程序生命周期管理

替换已弃用的on_event方法为推荐的lifespan事件处理器
添加应用程序关闭时的日志记录
代码质量改进

抽取常量到constants.py,减少硬编码值
添加helpers.py提供通用工具函数
优化配置管理,使用环境变量和默认值
完善文档字符串,提高代码可读性
2025-03-20 17:13:03 +08:00
snaily
8ca62707ea feat: 添加搜索模型配置并改进Markdown链接处理
在Dockerfile中添加SEARCH_MODELS环境变量,支持gemini-2.0-flash-exp和gemini-2.0-pro-exp模型
改进message_converter中的图片链接正则表达式
2025-03-19 19:56:50 +08:00
Toddy
21444ed6c7 chore: 统一从model_service读取模型列表 2025-03-18 18:05:00 +00:00
Toddy
ba292dbedd chore: 规范变量名 2025-03-18 17:54:18 +00:00
snaily
6ba58ce9d1 fix: 重构图片MIME类型转换逻辑
将"image/jpg"到"image/jpeg"的MIME类型转换逻辑从_convert_image函数移至_get_mime_type_and_data函数,避免代码重复并提高一致性。这确保了MIME类型的标准化处理发生在数据提取的同一位置。
2025-03-18 21:50:27 +08:00
snaily
16f16a3ae9 Merge branch 'pr/yangtb2024/13' 2025-03-18 21:46:34 +08:00
snaily
26dcb64687 fix: 将image/jpg MIME类型转换为标准的image/jpeg
修复了图像转换过程中的MIME类型处理,确保当遇到非标准的"image/jpg"类型时,将其转换为标准的"image/jpeg"类型。这样可以提高与接收图像数据的API和系统的兼容性
2025-03-18 21:35:19 +08:00
yangtb2024
df88492113 将chat-bison-001、text-bison-001和embedding-gecko-001添加到FILTERED_MODELS列表 2025-03-18 15:21:29 +08:00
yangtb2024
851bb9c09b 将 filtered_models 从硬编码改为可配置参数
1. 在 config.py 中添加 FILTERED_MODELS 配置项
2. 在 .env.example 中添加 FILTERED_MODELS 示例
3. 修改 model_service.py 以使用配置的过滤模型列表
4. 优化模型过滤逻辑
2025-03-18 14:47:58 +08:00
yangtb2024
0cac178572 Merge branch 'snailyp:main' into model 2025-03-18 12:44:09 +08:00
yangtb2024
016e6e06ee Filter out vision-based Gemini models from model list 2025-03-17 13:56:01 +08:00
35 changed files with 1309 additions and 529 deletions

View File

@@ -1,13 +1,17 @@
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
MODEL_IMAGE=["gemini-2.0-flash-exp"]
TEST_MODEL=gemini-1.5-flash
IMAGE_MODELS=["gemini-2.0-flash-exp"]
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
TOOLS_CODE_EXECUTION_ENABLED=false
SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true
BASE_URL=https://generativelanguage.googleapis.com/v1beta
MAX_FAILURES=10
# 请求超时时间(秒)
TIME_OUT=300
#########################image_generate 相关配置###########################
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
@@ -18,6 +22,7 @@ CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_OPTIMIZER_ENABLED=false
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10

View File

@@ -2,8 +2,6 @@ name: Docker Image CI
on:
push:
# branches: [ "main" ]
tags: [ 'v*.*.*' ]
pull_request:
branches: [ "main" ]
@@ -43,20 +41,30 @@ jobs:
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=raw,value=latest,enable={{is_default_branch}}
# https://github.com/docker/metadata-action/tree/v5/?tab=readme-ov-file#semver
# Event: push, Ref: refs/head/main, Tags: main
# Event: push tag, Ref: refs/tags/v1.2.3, Tags: 1.2.3, 1.2, 1, latest
# Event: push tag, Ref: refs/tags/v2.0.8-rc1, Tags: 2.0.8-rc1
type=ref,event=branch
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha,format=long
type=semver,pattern={{major}}
labels: |
org.opencontainers.image.description=OpenAI API Compatible Server
org.opencontainers.image.source=${{ github.event.repository.html_url }}
- name: Build and push Docker image
uses: docker/build-push-action@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Build and push
uses: docker/build-push-action@v6
with:
file: Dockerfile
context: .
platforms: linux/amd64,linux/arm64
push: ${{ github.event_name != 'pull_request' }}
load: false
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=${{ github.workflow }}
cache-to: type=gha,scope=${{ github.workflow }}

View File

@@ -11,7 +11,8 @@ ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
# Expose port
EXPOSE 8000

View File

@@ -1,4 +1,4 @@
# 🚀 FastAPI OpenAI (Gemini) 代理服务
# 🚀 Gemini 代理服务支持openai/gemini格式
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
@@ -64,10 +64,13 @@
AUTH_TOKEN="" # 超级管理员token具有所有权限默认使用 ALLOWED_TOKENS 的第一个
# 模型功能配置
MODEL_SEARCH=["gemini-2.0-flash-exp"] # 支持搜索功能的模型列表
TEST_MODEL="gemini-1.5-flash" # 用于测试密钥是否可用的模型
SEARCH_MODELS=["gemini-2.0-flash-exp"] # 支持搜索功能的模型列表
IMAGE_MODELS=["gemini-2.0-flash-exp"] # 支持绘图功能的模型列表
TOOLS_CODE_EXECUTION_ENABLED=false # 是否启用代码执行工具默认false
SHOW_SEARCH_LINK=true # 是否在响应中显示搜索结果链接默认true
SHOW_THINKING_PROCESS=true # 是否显示模型思考过程默认true
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"] # 被禁用的模型列表
# 图片生成配置
PAID_KEY="your-paid-api-key" # 付费版API Key用于图片生成等高级功能
@@ -115,9 +118,17 @@
#### 模型功能配置
- `MODEL_SEARCH`: 搜索功能支持的模型
- `TEST_MODEL`: 用于测试密钥可用性的模型
- 默认值: `"gemini-1.5-flash"`
- `SEARCH_MODELS`: 搜索功能支持的模型
- 默认值: `["gemini-2.0-flash-exp"]`
- 说明: 仅列表中的模型可使用搜索功能
- `IMAGE_MODELS`: 绘图功能支持的模型
- 默认值: `["gemini-2.0-flash-exp"]`
- 说明: 仅列表中的模型可使用绘图功能
- `FILTERED_MODELS`: 被禁用的模型列表
- 默认值: `["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]`
- 说明: 列表中的模型将被禁用
- `TOOLS_CODE_EXECUTION_ENABLED`: 代码执行功能
- 默认值: `false`
- 安全提示: 生产环境建议禁用
@@ -252,7 +263,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
"content": "你好"
}
],
"model": "gemini-1.5-flash-002",
"model": "gemini-1.5-flash",
"temperature": 0.7,
"stream": false,
"tools": [],
@@ -265,7 +276,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
- `messages`: 消息列表,格式与 OpenAI API 相同
- `model`: 模型名称支持所有Gemini模型包括:
- `gemini-1.5-flash-002`: 快速响应模型
- `gemini-1.5-flash`: 快速响应模型
- `gemini-2.0-flash-exp`: 实验性快速响应模型
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
- `stream`: 是否开启流式响应,`true` 或 `false`

57
app/config/config.py Normal file
View File

@@ -0,0 +1,57 @@
"""
应用程序配置模块
"""
from typing import List
from pydantic_settings import BaseSettings
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT
class Settings(BaseSettings):
"""应用程序配置"""
# API相关配置
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
TEST_MODEL: str = DEFAULT_MODEL
TIME_OUT: int = DEFAULT_TIMEOUT
# 模型相关配置
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
# 图像生成相关配置
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
# 流式输出优化器配置
STREAM_OPTIMIZER_ENABLED: bool = False
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 设置默认AUTH_TOKEN如果未提供
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
class Config:
env_file = ".env"
# 创建全局配置实例
settings = Settings()

71
app/core/application.py Normal file
View File

@@ -0,0 +1,71 @@
"""
应用程序工厂模块负责创建和配置FastAPI应用程序实例
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from app.config.config import settings
from app.log.logger import get_application_logger
from app.middleware.middleware import setup_middlewares
from app.exception.exceptions import setup_exception_handlers
from app.router.routes import setup_routers
from app.service.key.key_manager import get_key_manager_instance
from app.core.initialization import initialize_app
logger = get_application_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用程序生命周期管理器
Args:
app: FastAPI应用实例
"""
# 启动事件
logger.info("Application starting up...")
try:
# 初始化KeyManager
await get_key_manager_instance(settings.API_KEYS)
logger.info("KeyManager initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize KeyManager: {str(e)}")
raise
yield # 应用程序运行期间
# 关闭事件
logger.info("Application shutting down...")
def create_app() -> FastAPI:
"""
创建并配置FastAPI应用程序实例
Returns:
FastAPI: 配置好的FastAPI应用程序实例
"""
# 初始化应用程序
initialize_app()
# 创建FastAPI应用
app = FastAPI(
title="Gemini Balance API",
description="Gemini API代理服务支持负载均衡和密钥管理",
version="1.0.0",
lifespan=lifespan
)
# 配置静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# 配置中间件
setup_middlewares(app)
# 配置异常处理器
setup_exception_handlers(app)
# 配置路由
setup_routers(app)
return app

View File

@@ -1,41 +0,0 @@
from pydantic_settings import BaseSettings
from typing import List
class Settings(BaseSettings):
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
MODEL_IMAGE: List[str] = ["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置
STREAM_MIN_DELAY: float = 0.016
STREAM_MAX_DELAY: float = 0.024
STREAM_SHORT_TEXT_THRESHOLD: int = 10
STREAM_LONG_TEXT_THRESHOLD: int = 50
STREAM_CHUNK_SIZE: int = 5
def __init__(self):
super().__init__()
if not self.AUTH_TOKEN:
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else ""
class Config:
env_file = ".env"
settings = Settings()

41
app/core/constants.py Normal file
View File

@@ -0,0 +1,41 @@
"""
常量定义模块
"""
# API相关常量
API_VERSION = "v1beta"
DEFAULT_TIMEOUT = 300 # 秒
# 模型相关常量
SUPPORTED_ROLES = ["user", "model", "system"]
DEFAULT_MODEL = "gemini-1.5-flash"
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 8192
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 40
DEFAULT_FILTER_MODELS = [
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"chat-bison-001",
"text-bison-001",
"embedding-gecko-001"
]
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
# 图像生成相关常量
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
# 上传提供商
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
DEFAULT_UPLOAD_PROVIDER = "smms"
# 流式输出相关常量
DEFAULT_STREAM_MIN_DELAY = 0.016
DEFAULT_STREAM_MAX_DELAY = 0.024
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
DEFAULT_STREAM_CHUNK_SIZE = 5
# 正则表达式模式
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'

View File

@@ -0,0 +1,40 @@
"""
应用程序初始化模块
"""
from pathlib import Path
from typing import List
from app.log.logger import get_initialization_logger
logger = get_initialization_logger()
def ensure_directories_exist(directories: List[str]) -> None:
"""
确保指定的目录存在,如果不存在则创建
Args:
directories: 要确保存在的目录列表
"""
for directory in directories:
try:
Path(directory).mkdir(parents=True, exist_ok=True)
logger.info(f"Ensured directory exists: {directory}")
except Exception as e:
logger.error(f"Failed to create directory {directory}: {str(e)}")
def initialize_app() -> None:
"""
初始化应用程序,确保所需的目录和文件都存在
"""
# 确保必要的目录存在
required_directories = [
"app/static/css",
"app/static/js",
"app/static/icons",
"app/templates",
]
ensure_directories_exist(required_directories)
logger.info("Application initialization completed")

View File

@@ -1,13 +1,17 @@
from fastapi import HTTPException, Header
from typing import Optional
from app.core.logger import get_security_logger
from app.core.config import settings
from fastapi import Header, HTTPException
from app.config.config import settings
from app.log.logger import get_security_logger
logger = get_security_logger()
def verify_auth_token(token: str) -> bool:
return token == settings.AUTH_TOKEN
class SecurityService:
def __init__(self, allowed_tokens: list, auth_token: str):
self.allowed_tokens = allowed_tokens
@@ -20,7 +24,7 @@ class SecurityService:
return key
async def verify_authorization(
self, authorization: Optional[str] = Header(None)
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing Authorization header")
@@ -39,19 +43,26 @@ class SecurityService:
return token
async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str:
async def verify_goog_api_key(
self, x_goog_api_key: Optional[str] = Header(None)
) -> str:
"""验证Google API Key"""
if not x_goog_api_key:
logger.error("Missing x-goog-api-key header")
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
if (
x_goog_api_key not in self.allowed_tokens
and x_goog_api_key != self.auth_token
):
logger.error("Invalid x-goog-api-key")
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
return x_goog_api_key
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
async def verify_auth_token(
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing auth_token header")
raise HTTPException(status_code=401, detail="Missing auth_token header")
@@ -61,3 +72,22 @@ class SecurityService:
raise HTTPException(status_code=401, detail="Invalid auth_token")
return token
async def verify_key_or_goog_api_key(
self, key: Optional[str] = None , x_goog_api_key: Optional[str] = Header(None)
) -> str:
"""验证URL中的key或请求头中的x-goog-api-key"""
# 如果URL中的key有效直接返回
if key in self.allowed_tokens or key == self.auth_token:
return key
# 否则检查请求头中的x-goog-api-key
if not x_goog_api_key:
logger.error("Invalid key and missing x-goog-api-key header")
raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
logger.error("Invalid key and invalid x-goog-api-key")
raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")
return x_goog_api_key

View File

@@ -1,6 +1,8 @@
from typing import List, Optional, Dict, Any, Literal
from typing import List, Optional, Dict, Any, Literal, Union
from pydantic import BaseModel
from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
class SafetySetting(BaseModel):
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
@@ -13,9 +15,9 @@ class GenerationConfig(BaseModel):
responseSchema: Optional[Dict[str, Any]] = None
candidateCount: Optional[int] = 1
maxOutputTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
topK: Optional[int] = None
temperature: Optional[float] = DEFAULT_TEMPERATURE
topP: Optional[float] = DEFAULT_TOP_P
topK: Optional[int] = DEFAULT_TOP_K
presencePenalty: Optional[float] = None
frequencyPenalty: Optional[float] = None
responseLogprobs: Optional[bool] = None
@@ -34,7 +36,7 @@ class GeminiContent(BaseModel):
class GeminiRequest(BaseModel):
contents: List[GeminiContent] = []
tools: Optional[List[Dict[str, Any]]] = []
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = {}
generationConfig: Optional[GenerationConfig] = None
systemInstruction: Optional[SystemInstruction] = None

View File

@@ -1,17 +1,19 @@
from pydantic import BaseModel
from typing import List, Optional, Union
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
class ChatRequest(BaseModel):
messages: List[dict]
model: str = "gemini-1.5-flash-002"
temperature: Optional[float] = 0.7
model: str = DEFAULT_MODEL
temperature: Optional[float] = DEFAULT_TEMPERATURE
stream: Optional[bool] = False
tools: Optional[List[dict]] = []
max_tokens: Optional[int] = 8192
max_tokens: Optional[int] = None
top_p: Optional[float] = DEFAULT_TOP_P
top_k: Optional[int] = DEFAULT_TOP_K
stop: Optional[List[str]] = []
top_p: Optional[float] = 0.9
top_k: Optional[int] = 40
class EmbeddingRequest(BaseModel):

140
app/exception/exceptions.py Normal file
View File

@@ -0,0 +1,140 @@
"""
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
"""
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.log.logger import get_exceptions_logger
logger = get_exceptions_logger()
class APIError(Exception):
"""API错误基类"""
def __init__(self, status_code: int, detail: str, error_code: str = None):
self.status_code = status_code
self.detail = detail
self.error_code = error_code or "api_error"
super().__init__(self.detail)
class AuthenticationError(APIError):
"""认证错误"""
def __init__(self, detail: str = "Authentication failed"):
super().__init__(
status_code=401, detail=detail, error_code="authentication_error"
)
class AuthorizationError(APIError):
"""授权错误"""
def __init__(self, detail: str = "Not authorized to access this resource"):
super().__init__(
status_code=403, detail=detail, error_code="authorization_error"
)
class ResourceNotFoundError(APIError):
"""资源未找到错误"""
def __init__(self, detail: str = "Resource not found"):
super().__init__(
status_code=404, detail=detail, error_code="resource_not_found"
)
class ModelNotSupportedError(APIError):
"""模型不支持错误"""
def __init__(self, model: str):
super().__init__(
status_code=400,
detail=f"Model {model} is not supported",
error_code="model_not_supported",
)
class APIKeyError(APIError):
"""API密钥错误"""
def __init__(self, detail: str = "Invalid or expired API key"):
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
class ServiceUnavailableError(APIError):
"""服务不可用错误"""
def __init__(self, detail: str = "Service temporarily unavailable"):
super().__init__(
status_code=503, detail=detail, error_code="service_unavailable"
)
def setup_exception_handlers(app: FastAPI) -> None:
"""
设置应用程序的异常处理器
Args:
app: FastAPI应用程序实例
"""
@app.exception_handler(APIError)
async def api_error_handler(request: Request, exc: APIError):
"""处理API错误"""
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
return JSONResponse(
status_code=exc.status_code,
content={"error": {"code": exc.error_code, "message": exc.detail}},
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""处理HTTP异常"""
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
return JSONResponse(
status_code=exc.status_code,
content={"error": {"code": "http_error", "message": exc.detail}},
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
"""处理请求验证错误"""
error_details = []
for error in exc.errors():
error_details.append(
{"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
)
logger.error(f"Validation Error: {error_details}")
return JSONResponse(
status_code=422,
content={
"error": {
"code": "validation_error",
"message": "Request validation failed",
"details": error_details,
}
},
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""处理通用异常"""
logger.exception(f"Unhandled Exception: {str(exc)}")
return JSONResponse(
status_code=500,
content={
"error": {
"code": "internal_server_error",
"message": "An unexpected error occurred",
}
},
)

View File

@@ -1,13 +1,13 @@
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
import json
import re
from typing import Any, Dict, List, Optional
import requests
import base64
SUPPORTED_ROLES = ["user", "model", "system"]
IMAGE_URL_PATTERN = r'\[image\]\((.*?)\)'
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
class MessageConverter(ABC):
@@ -30,10 +30,10 @@ def _get_mime_type_and_data(base64_string):
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = r'data:([^;]+);base64,(.+)'
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = match.group(1)
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
@@ -49,11 +49,14 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
"data": encoded_data
}
}
return {
"image_url": {
"url": image_url
else:
encoded_data = _convert_image_to_base64(image_url)
return {
"inline_data": {
"mime_type": "image/png",
"data": encoded_data
}
}
}
def _convert_image_to_base64(url: str) -> str:
@@ -87,7 +90,7 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
img_url_match = re.search(IMAGE_URL_PATTERN, text)
if img_url_match:
# 提取URL
img_url = img_url_match.group(1)
img_url = img_url_match.group(2)
# 将URL对应的图片转换为base64
try:
base64_data = _convert_image_to_base64(img_url)
@@ -115,6 +118,36 @@ class OpenAIMessageConverter(MessageConverter):
for idx, msg in enumerate(messages):
role = msg.get("role", "")
parts = []
# 特别处理最后一个assistant的消息按\n\n分割
if "content" in msg and isinstance(msg["content"], str) and msg["content"] and role == "assistant" and idx == len(messages) - 2:
# 按\n\n分割消息
content_parts = msg["content"].split("\n\n")
for part in content_parts:
if not part.strip(): # 跳过空内容
continue
# 处理可能包含图片的文本
parts.extend(_process_text_with_image(part))
elif "content" in msg and isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.extend(_process_text_with_image(msg["content"]))
elif "content" in msg and isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str) and content:
parts.append({"text": content})
elif isinstance(content, dict):
if content["type"] == "text" and content["text"]:
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
for tool_call in msg["tool_calls"]:
function_call = tool_call.get("function",{})
function_call["args"] = json.loads(function_call.get("arguments","{}"))
del function_call["arguments"]
parts.append({"functionCall": function_call})
if role not in SUPPORTED_ROLES:
if role == "tool":
role = "user"
@@ -124,30 +157,6 @@ class OpenAIMessageConverter(MessageConverter):
role = "user"
else:
role = "model"
parts = []
# 特别处理最后一个assistant的消息按\n\n分割
if role == "assistant" and idx == len(messages) - 2 and isinstance(msg["content"], str) and msg["content"]:
# 按\n\n分割消息
content_parts = msg["content"].split("\n\n")
for part in content_parts:
if not part.strip(): # 跳过空内容
continue
# 处理可能包含图片的文本
parts.extend(_process_text_with_image(part))
elif isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.extend(_process_text_with_image(msg["content"]))
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str) and content:
parts.append({"text": content})
elif isinstance(content, dict):
if content["type"] == "text" and content["text"]:
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
if parts:
if role == "system":
system_instruction_parts.extend(parts)

View File

@@ -8,8 +8,8 @@ from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
import time
import uuid
from app.core.config import settings
from app.core.uploader import ImageUploaderFactory
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
class ResponseHandler(ABC):
@@ -205,11 +205,11 @@ def _extract_image_data(part: dict) -> str:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
#将base64_data转成bytes数组
#将base64_data转成bytes数组
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename)
if upload_response.success:
text = f"![image]({upload_response.data.url})"
text = f"\n\n![image]({upload_response.data.url})\n\n"
else:
text = ""
return text

View File

@@ -1,10 +1,11 @@
# app/services/chat/retry_handler.py
from typing import TypeVar, Callable
from functools import wraps
from app.core.logger import get_retry_logger
from typing import Callable, TypeVar
T = TypeVar('T')
from app.log.logger import get_retry_logger
T = TypeVar("T")
logger = get_retry_logger()
@@ -25,17 +26,21 @@ class RetryHandler:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
logger.warning(
f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}"
)
# 从函数参数中获取 key_manager
key_manager = kwargs.get('key_manager')
key_manager = kwargs.get("key_manager")
if key_manager:
old_key = kwargs.get(self.key_arg)
new_key = await key_manager.handle_api_failure(old_key)
kwargs[self.key_arg] = new_key
logger.info(f"Switched to new API key: {new_key}")
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
logger.error(
f"All retry attempts failed, raising final exception: {str(last_exception)}"
)
raise last_exception
return wrapper

View File

@@ -2,9 +2,17 @@
import asyncio
import math
from typing import Any, List, AsyncGenerator, Callable
from app.core.logger import get_openai_logger, get_gemini_logger
from app.core.config import settings
from typing import Any, AsyncGenerator, Callable, List
from app.config.config import settings
from app.core.constants import (
DEFAULT_STREAM_CHUNK_SIZE,
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
DEFAULT_STREAM_MAX_DELAY,
DEFAULT_STREAM_MIN_DELAY,
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
)
from app.log.logger import get_gemini_logger, get_openai_logger
logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger()
@@ -12,19 +20,21 @@ logger_gemini = get_gemini_logger()
class StreamOptimizer:
"""流式输出优化器
提供流式输出优化功能包括智能延迟调整和长文本分块输出
"""
def __init__(self,
logger=None,
min_delay: float = 0.016,
max_delay: float = 0.024,
short_text_threshold: int = 10,
long_text_threshold: int = 50,
chunk_size: int = 5):
def __init__(
self,
logger=None,
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
):
"""初始化流式输出优化器
参数:
logger: 日志记录器
min_delay: 最小延迟时间
@@ -39,13 +49,13 @@ class StreamOptimizer:
self.short_text_threshold = short_text_threshold
self.long_text_threshold = long_text_threshold
self.chunk_size = chunk_size
def calculate_delay(self, text_length: int) -> float:
"""根据文本长度计算延迟时间
参数:
text_length: 文本长度
返回:
延迟时间
"""
@@ -58,48 +68,54 @@ class StreamOptimizer:
else:
# 中等长度文本使用线性插值计算延迟
# 使用对数函数使延迟变化更平滑
ratio = math.log(text_length / self.short_text_threshold) / math.log(self.long_text_threshold / self.short_text_threshold)
ratio = math.log(text_length / self.short_text_threshold) / math.log(
self.long_text_threshold / self.short_text_threshold
)
return self.max_delay - ratio * (self.max_delay - self.min_delay)
def split_text_into_chunks(self, text: str) -> List[str]:
"""将文本分割成小块
参数:
text: 要分割的文本
返回:
文本块列表
"""
return [text[i:i+self.chunk_size] for i in range(0, len(text), self.chunk_size)]
async def optimize_stream_output(self,
text: str,
create_response_chunk: Callable[[str], Any],
format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]:
return [
text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
]
async def optimize_stream_output(
self,
text: str,
create_response_chunk: Callable[[str], Any],
format_chunk: Callable[[Any], str],
) -> AsyncGenerator[str, None]:
"""优化流式输出
参数:
text: 要输出的文本
create_response_chunk: 创建响应块的函数接收文本返回响应块
format_chunk: 格式化响应块的函数接收响应块返回格式化后的字符串
返回:
异步生成器生成格式化后的响应块
"""
if not text:
return
# 计算智能延迟时间
delay = self.calculate_delay(len(text))
if self.logger:
self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
# if self.logger:
# self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
# 根据文本长度决定输出方式
if len(text) >= self.long_text_threshold:
# 长文本:分块输出
chunks = self.split_text_into_chunks(text)
if self.logger:
self.logger.info(f"Long text: splitting into {len(chunks)} chunks")
# if self.logger:
# self.logger.info(f"Long text: splitting into {len(chunks)} chunks")
for chunk_text in chunks:
chunk_response = create_response_chunk(chunk_text)
yield format_chunk(chunk_response)
@@ -119,7 +135,7 @@ openai_optimizer = StreamOptimizer(
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
chunk_size=settings.STREAM_CHUNK_SIZE,
)
gemini_optimizer = StreamOptimizer(
@@ -128,5 +144,5 @@ gemini_optimizer = StreamOptimizer(
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
chunk_size=settings.STREAM_CHUNK_SIZE,
)

View File

@@ -133,3 +133,23 @@ def get_retry_logger():
def get_image_create_logger():
return Logger.setup_logger("image_create")
def get_exceptions_logger():
return Logger.setup_logger("exceptions")
def get_application_logger():
return Logger.setup_logger("application")
def get_initialization_logger():
return Logger.setup_logger("initialization")
def get_middleware_logger():
return Logger.setup_logger("middleware")
def get_routes_logger():
return Logger.setup_logger("routes")

View File

@@ -1,134 +1,18 @@
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from app.core.logger import get_main_logger
from app.core.security import verify_auth_token
from app.services.key_manager import get_key_manager_instance
from app.core.config import settings
"""
应用程序入口模块
"""
from app.api import gemini_routes, openai_routes
import uvicorn
from app.core.application import create_app
from app.log.logger import get_main_logger
# 创建应用程序实例
app = create_app()
# 配置日志
logger = get_main_logger()
app = FastAPI()
# 配置Jinja2模板
templates = Jinja2Templates(directory="app/templates")
# 配置静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# 创建 KeyManager 实例
key_manager = None
@app.on_event("startup")
async def startup_event():
global key_manager
logger.info("Application starting up...")
try:
key_manager = await get_key_manager_instance(settings.API_KEYS)
logger.info("KeyManager initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize KeyManager: {str(e)}")
raise
# 添加中间件来处理未经身份验证的请求
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
# 允许 gemini_routes 和 openai_routes 中的端点绕过身份验证
if (request.url.path not in ["/", "/auth"] and
not request.url.path.startswith("/static") and
not request.url.path.startswith("/gemini") and
not request.url.path.startswith("/v1") and
not request.url.path.startswith("/v1beta") and
not request.url.path.startswith("/health") and
not request.url.path.startswith("/hf")):
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning(f"Unauthorized access attempt to {request.url.path}")
return RedirectResponse(url="/")
logger.debug("Request authenticated successfully")
response = await call_next(request)
return response
# 添加请求日志中间件
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)
)
# 包含所有路由
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
@app.get("/", response_class=HTMLResponse)
async def auth_page(request: Request):
return templates.TemplateResponse("auth.html", {"request": request})
@app.post("/auth")
async def authenticate(request: Request):
try:
form = await request.form()
auth_token = form.get("auth_token")
if not auth_token:
logger.warning("Authentication attempt with empty token")
return RedirectResponse(url="/", status_code=302)
if verify_auth_token(auth_token):
logger.info("Successful authentication")
response = RedirectResponse(url="/keys", status_code=302)
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
return response
logger.warning("Failed authentication attempt with invalid token")
return RedirectResponse(url="/", status_code=302)
except Exception as e:
logger.error(f"Authentication error: {str(e)}")
return RedirectResponse(url="/", status_code=302)
@app.get("/keys", response_class=HTMLResponse)
async def keys_page(request: Request):
try:
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning("Unauthorized access attempt to keys page")
return RedirectResponse(url="/", status_code=302)
keys_status = await key_manager.get_keys_by_status()
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
return templates.TemplateResponse("keys_status.html", {
"request": request,
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"],
"total": total
})
except Exception as e:
logger.error(f"Error retrieving keys status: {str(e)}")
raise
@app.get("/health")
async def health_check(request: Request):
logger.info("Health check endpoint called")
return {"status": "healthy"}
if __name__ == "__main__":
logger.info("Starting application server...")
uvicorn.run(app, host="0.0.0.0", port=8001)

View File

@@ -0,0 +1,73 @@
"""
中间件配置模块,负责设置和配置应用程序的中间件
"""
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
from app.core.constants import API_VERSION
from app.core.security import verify_auth_token
from app.log.logger import get_middleware_logger
logger = get_middleware_logger()
class AuthMiddleware(BaseHTTPMiddleware):
"""
认证中间件,处理未经身份验证的请求
"""
async def dispatch(self, request: Request, call_next):
# 允许特定路径绕过身份验证
if (
request.url.path not in ["/", "/auth"]
and not request.url.path.startswith("/static")
and not request.url.path.startswith("/gemini")
and not request.url.path.startswith("/v1")
and not request.url.path.startswith(f"/{API_VERSION}")
and not request.url.path.startswith("/health")
and not request.url.path.startswith("/hf")
):
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning(f"Unauthorized access attempt to {request.url.path}")
return RedirectResponse(url="/")
logger.debug("Request authenticated successfully")
response = await call_next(request)
return response
def setup_middlewares(app: FastAPI) -> None:
"""
设置应用程序的中间件
Args:
app: FastAPI应用程序实例
"""
# 添加认证中间件
app.add_middleware(AuthMiddleware)
# 添加请求日志中间件(可选,默认注释掉)
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=[
"GET",
"POST",
"PUT",
"DELETE",
"OPTIONS",
], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)
)

View File

@@ -1,7 +1,9 @@
import json
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import json
from app.core.logger import get_request_logger
from app.log.logger import get_request_logger
logger = get_request_logger()
@@ -20,7 +22,9 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
# 尝试格式化JSON
try:
formatted_body = json.loads(body_str)
logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
logger.info(
f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
)
except json.JSONDecodeError:
logger.info("Request body is not valid JSON.")
except Exception as e:

View File

@@ -1,73 +1,80 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.core.config import settings
from app.core.logger import get_gemini_logger
from app.config.config import settings
from app.log.logger import get_gemini_logger
from app.core.security import SecurityService
from app.schemas.gemini_models import GeminiContent, GeminiRequest
from app.services.gemini_chat_service import GeminiChatService
from app.services.key_manager import KeyManager, get_key_manager_instance
from app.services.model_service import ModelService
from app.services.chat.retry_handler import RetryHandler
from app.domain.gemini_models import GeminiContent, GeminiRequest
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.core.constants import API_VERSION
router = APIRouter(prefix="/gemini/v1beta")
router_v1beta = APIRouter(prefix="/v1beta")
# 路由设置
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
logger = get_gemini_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
async def get_key_manager():
"""获取密钥管理器实例"""
return await get_key_manager_instance()
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
return await key_manager.get_next_working_key()
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
"""获取下一个可用的API密钥"""
return await key_manager.get_next_working_key()
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(_=Depends(security_service.verify_key),
key_manager: KeyManager = Depends(get_key_manager)):
async def list_models(
_=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager)
):
"""获取可用的Gemini模型列表"""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_next_working_key()
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
# 模型名称以及对应的详细信息
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
# 添加搜索模型
if settings.MODEL_SEARCH:
for name in settings.MODEL_SEARCH:
model = model_mapping.get(name, None)
if model_service.search_models:
for name in model_service.search_models:
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if settings.MODEL_IMAGE:
for name in settings.MODEL_IMAGE:
model = model_mapping.get(name, None)
if model_service.image_models:
for name in model_service.image_models:
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json
@@ -77,30 +84,29 @@ async def list_models(_=Depends(security_service.verify_key),
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager)
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager)
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""非流式生成内容"""
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@@ -110,45 +116,46 @@ async def generate_content(
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(max_retries=3, key_arg="api_key")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager)
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager)
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""流式生成内容"""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Streaming request failed") from e
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str):
key_manager = await get_key_manager()
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""验证Gemini API密钥的有效性"""
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
logger.info("Verifying API key validity")
try:
key_manager = await get_key_manager()
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
# 使用generate_content接口测试key的有效性
gemini_requset = GeminiRequest(
gemini_request = GeminiRequest(
contents=[
GeminiContent(
role="user",
@@ -156,10 +163,16 @@ async def verify_key(api_key: str):
)
]
)
response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
response = await chat_service.generate_content(
settings.TEST_MODEL,
gemini_request,
api_key
)
if response:
return JSONResponse({"status": "valid"})
return JSONResponse({"status": "invalid"})
except Exception as e:
logger.error(f"Key verification failed: {str(e)}")
return JSONResponse({"status": "invalid", "error": str(e)})
return JSONResponse({"status": "invalid", "error": str(e)})

View File

@@ -1,47 +1,58 @@
from fastapi import HTTPException, APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.core.logger import get_openai_logger
from app.config.config import settings
from app.core.security import SecurityService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
from app.services.chat.retry_handler import RetryHandler
from app.services.embedding_service import EmbeddingService
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager, get_key_manager_instance
from app.services.model_service import ModelService
from app.services.openai_chat_service import OpenAIChatService
from app.domain.openai_models import (
ChatRequest,
EmbeddingRequest,
ImageGenerationRequest,
)
from app.handler.retry_handler import RetryHandler
from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
router = APIRouter()
logger = get_openai_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService()
async def get_key_manager():
return await get_key_manager_instance()
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
async def get_next_working_key_wrapper(
key_manager: KeyManager = Depends(get_key_manager),
):
return await key_manager.get_next_working_key()
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "list_models" + "-" * 50)
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
try:
return model_service.get_gemini_openai_models(api_key)
except Exception as e:
logger.error(f"Error getting models list: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
raise HTTPException(
status_code=500, detail="Internal server error while fetching models list"
) from e
@router.post("/v1/chat/completions")
@@ -51,7 +62,7 @@ async def chat_completion(
request: ChatRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
@@ -63,8 +74,10 @@ async def chat_completion(
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {request.model} is not supported"
)
try:
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
@@ -80,6 +93,7 @@ async def chat_completion(
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
@@ -95,14 +109,17 @@ async def generate_image(
return response
except Exception as e:
logger.error(f"Image generation request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Image generation request failed") from e
raise HTTPException(
status_code=500, detail="Image generation request failed"
) from e
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "embedding" + "-" * 50)
logger.info(f"Handling embedding request for model: {request.model}")
@@ -118,11 +135,12 @@ async def embedding(
logger.error(f"Embedding request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Embedding request failed") from e
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
_=Depends(security_service.verify_auth_token),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
"""获取有效和无效的API key列表"""
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
@@ -133,13 +151,12 @@ async def get_keys_list(
"status": "success",
"data": {
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"]
"invalid_keys": keys_status["invalid_keys"],
},
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
}
except Exception as e:
logger.error(f"Error getting keys list: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal server error while fetching keys list"
status_code=500, detail="Internal server error while fetching keys list"
) from e

114
app/router/routes.py Normal file
View File

@@ -0,0 +1,114 @@
"""
路由配置模块,负责设置和配置应用程序的路由
"""
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from app.core.security import verify_auth_token
from app.log.logger import get_routes_logger
from app.router import gemini_routes, openai_routes
from app.service.key.key_manager import get_key_manager_instance
logger = get_routes_logger()
# 配置Jinja2模板
templates = Jinja2Templates(directory="app/templates")
def setup_routers(app: FastAPI) -> None:
"""
设置应用程序的路由
Args:
app: FastAPI应用程序实例
"""
# 包含API路由
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
# 添加页面路由
setup_page_routes(app)
# 添加健康检查路由
setup_health_routes(app)
def setup_page_routes(app: FastAPI) -> None:
"""
设置页面相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/", response_class=HTMLResponse)
async def auth_page(request: Request):
"""认证页面"""
return templates.TemplateResponse("auth.html", {"request": request})
@app.post("/auth")
async def authenticate(request: Request):
"""处理认证请求"""
try:
form = await request.form()
auth_token = form.get("auth_token")
if not auth_token:
logger.warning("Authentication attempt with empty token")
return RedirectResponse(url="/", status_code=302)
if verify_auth_token(auth_token):
logger.info("Successful authentication")
response = RedirectResponse(url="/keys", status_code=302)
response.set_cookie(
key="auth_token", value=auth_token, httponly=True, max_age=3600
)
return response
logger.warning("Failed authentication attempt with invalid token")
return RedirectResponse(url="/", status_code=302)
except Exception as e:
logger.error(f"Authentication error: {str(e)}")
return RedirectResponse(url="/", status_code=302)
@app.get("/keys", response_class=HTMLResponse)
async def keys_page(request: Request):
"""密钥管理页面"""
try:
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning("Unauthorized access attempt to keys page")
return RedirectResponse(url="/", status_code=302)
key_manager = await get_key_manager_instance()
keys_status = await key_manager.get_keys_by_status()
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
return templates.TemplateResponse(
"keys_status.html",
{
"request": request,
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"],
"total": total,
},
)
except Exception as e:
logger.error(f"Error retrieving keys status: {str(e)}")
raise
def setup_health_routes(app: FastAPI) -> None:
"""
设置健康检查相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/health")
async def health_check(request: Request):
"""健康检查端点"""
logger.info("Health check endpoint called")
return {"status": "healthy"}

View File

@@ -1,14 +1,15 @@
# app/services/chat_service.py
import json
from typing import Dict, Any, AsyncGenerator, List
from app.core.logger import get_gemini_logger
from app.services.chat.api_client import GeminiApiClient
from app.services.chat.stream_optimizer import gemini_optimizer
from app.schemas.gemini_models import GeminiRequest
from app.core.config import settings
from app.services.chat.response_handler import GeminiResponseHandler
from app.services.key_manager import KeyManager
from typing import Any, AsyncGenerator, Dict, List
from app.config.config import settings
from app.domain.gemini_models import GeminiRequest
from app.handler.response_handler import GeminiResponseHandler
from app.handler.stream_optimizer import gemini_optimizer
from app.log.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import KeyManager
logger = get_gemini_logger()
@@ -25,20 +26,45 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not _has_image_parts(payload.get("contents", [])):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
record = dict()
for item in tools:
if not item or not isinstance(item, dict):
continue
for k, v in item.items():
if k == "functionDeclarations" and v and isinstance(v, list):
functions = record.get("functionDeclarations", [])
functions.extend(v)
record["functionDeclarations"] = functions
else:
record[k] = v
return record
tool = dict()
if payload and isinstance(payload, dict) and "tools" in payload:
if payload.get("tools") and isinstance(payload.get("tools"), dict):
payload["tools"] = [payload.get("tools")]
items = payload.get("tools", [])
if items and isinstance(items, list):
tools.extend(items)
tool.update(_merge_tools(items))
return tools
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not _has_image_parts(payload.get("contents", []))
):
tool["codeExecution"] = {}
if model.endswith("-search"):
tool["googleSearch"] = {}
# 解决 "Tool use with function calling is unsupported" 问题
if tool.get("functionDeclarations"):
tool.pop("googleSearch", None)
tool.pop("codeExecution", None)
return [tool] if tool else []
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
@@ -49,31 +75,36 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload"""
request_dict = request.model_dump()
if request.generationConfig:
if request.generationConfig.maxOutputTokens is None:
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
request_dict["generationConfig"].pop("maxOutputTokens")
payload = {
"contents": request_dict.get("contents", []),
"tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model),
"generationConfig": request_dict.get("generationConfig", {}),
"systemInstruction": request_dict.get("systemInstruction", "")
"systemInstruction": request_dict.get("systemInstruction", ""),
}
if model.endswith("-image") or model.endswith("-image-generation"):
payload.pop("systemInstruction")
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
return payload
@@ -84,54 +115,67 @@ class GeminiChatService:
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
"""从响应中提取文本内容"""
if not response.get("candidates"):
return ""
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if parts and "text" in parts[0]:
return parts[0].get("text", "")
return ""
def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]:
def _create_char_response(
self, original_response: Dict[str, Any], text: str
) -> Dict[str, Any]:
"""创建包含指定文本的响应"""
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
if response_copy.get("candidates") and response_copy["candidates"][0].get("content", {}).get("parts"):
if response_copy.get("candidates") and response_copy["candidates"][0].get(
"content", {}
).get("parts"):
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
return response_copy
async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
async def generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
"""生成内容"""
payload = _build_payload(model, request)
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
async def stream_generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> AsyncGenerator[str, None]:
"""流式生成内容"""
retries = 0
max_retries = 3
payload = _build_payload(model, request)
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(payload, model, api_key):
async for line in self.api_client.stream_generate_content(
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
line = line[6:]
response_data = self.response_handler.handle_response(json.loads(line), model, stream=True)
response_data = self.response_handler.handle_response(
json.loads(line), model, stream=True
)
text = self._extract_text_from_response(response_data)
# 如果有文本内容,使用流式输出优化器处理
if text:
# 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理
if text and settings.STREAM_OPTIMIZER_ENABLED:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in gemini_optimizer.optimize_stream_output(
async for (
optimized_chunk
) in gemini_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_response(response_data, t),
lambda c: "data: " + json.dumps(c) + "\n\n"
lambda c: "data: " + json.dumps(c) + "\n\n",
):
yield optimized_chunk
else:
@@ -141,9 +185,13 @@ class GeminiChatService:
break
except Exception as e:
retries += 1
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
logger.warning(
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
)
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
logger.error(
f"Max retries ({max_retries}) reached for streaming. Raising error"
)
break

View File

@@ -1,17 +1,18 @@
# app/services/chat_service.py
from copy import deepcopy
import json
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
from app.core.logger import get_openai_logger
from app.services.chat.message_converter import OpenAIMessageConverter
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.api_client import GeminiApiClient
from app.services.chat.stream_optimizer import openai_optimizer
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
from app.core.config import settings
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from app.config.config import settings
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
from app.handler.stream_optimizer import openai_optimizer
from app.log.logger import get_openai_logger
from app.service.client.api_client import GeminiApiClient
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager
logger = get_openai_logger()
@@ -27,30 +28,35 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
def _build_tools(
request: ChatRequest, messages: List[Dict[str, Any]]
request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
tool = dict()
model = request.model
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image") or model.endswith("-image-generation"))
and not _has_image_parts(messages)
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (
model.endswith("-search")
or "-thinking" in model
or model.endswith("-image")
or model.endswith("-image-generation")
)
and not _has_image_parts(messages)
):
tools.append({"code_execution": {}})
tool["codeExecution"] = {}
if model.endswith("-search"):
tools.append({"googleSearch": {}})
tool["googleSearch"] = {}
# 将 request 中的 tools 合并到 tools 中
if request.tools:
function_declarations = []
for tool in request.tools:
if not tool or not isinstance(tool, dict):
for item in request.tools:
if not item or not isinstance(item, dict):
continue
if tool.get("type", "") == "function" and tool.get("function"):
function = deepcopy(tool.get("function"))
if item.get("type", "") == "function" and item.get("function"):
function = deepcopy(item.get("function"))
parameters = function.get("parameters", {})
if parameters.get("type") == "object" and not parameters.get("properties", {}):
function.pop("parameters", None)
@@ -60,14 +66,19 @@ def _build_tools(
if function_declarations:
# 按照 function 的 name 去重
names, functions = set(), []
for item in function_declarations:
if item.get("name") not in names:
names.add(item.get("name"))
functions.append(item)
for fc in function_declarations:
if fc.get("name") not in names:
names.add(fc.get("name"))
functions.append(fc)
tools.append({"functionDeclarations": functions})
return tools
tool["functionDeclarations"] = functions
# 解决 "Tool use with function calling is unsupported" 问题
if tool.get("functionDeclarations"):
tool.pop("googleSearch", None)
tool.pop("codeExecution", None)
return [tool] if tool else []
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
@@ -95,14 +106,15 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
request: ChatRequest,
messages: List[Dict[str, Any]],
instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""构建请求payload"""
payload = {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
"maxOutputTokens": request.max_tokens,
"stopSequences": request.stop,
"topP": request.top_p,
"topK": request.top_k,
@@ -110,9 +122,11 @@ def _build_payload(
"tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model),
}
if request.max_tokens is not None:
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
if (
instruction
and isinstance(instruction, dict)
@@ -128,24 +142,27 @@ def _build_payload(
class OpenAIChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager = None):
self.message_converter = OpenAIMessageConverter()
self.response_handler = OpenAIResponseHandler(config=None)
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.image_create_service = ImageCreateService()
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
"""从OpenAI响应块中提取文本内容"""
if not chunk.get("choices"):
return ""
choice = chunk["choices"][0]
if "delta" in choice and "content" in choice["delta"]:
return choice["delta"]["content"]
return ""
def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]:
def _create_char_openai_chunk(
self, original_chunk: Dict[str, Any], text: str
) -> Dict[str, Any]:
"""创建包含指定文本的OpenAI响应块"""
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
@@ -153,9 +170,9 @@ class OpenAIChatService:
return chunk_copy
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
@@ -169,7 +186,7 @@ class OpenAIChatService:
return await self._handle_normal_completion(request.model, payload, api_key)
async def _handle_normal_completion(
self, model: str, payload: Dict[str, Any], api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = await self.api_client.generate_content(payload, model, api_key)
@@ -178,15 +195,16 @@ class OpenAIChatService:
)
async def _handle_stream_completion(
self, model: str, payload: Dict[str, Any], api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> AsyncGenerator[str, None]:
"""处理流式聊天完成,添加重试逻辑"""
retries = 0
max_retries = 3
while retries < max_retries:
try:
tool_call_flag = False
async for line in self.api_client.stream_generate_content(
payload, model, api_key
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
@@ -197,18 +215,27 @@ class OpenAIChatService:
if openai_chunk:
# 提取文本内容
text = self._extract_text_from_openai_chunk(openai_chunk)
if text:
if text and settings.STREAM_OPTIMIZER_ENABLED:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in openai_optimizer.optimize_stream_output(
async for (
optimized_chunk
) in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(openai_chunk, t),
lambda c: f"data: {json.dumps(c)}\n\n"
lambda t: self._create_char_openai_chunk(
openai_chunk, t
),
lambda c: f"data: {json.dumps(c)}\n\n",
):
yield optimized_chunk
else:
# 如果没有文本内容(如工具调用等),整块输出
if "tool_calls" in json.dumps(openai_chunk):
tool_call_flag = True
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
if tool_call_flag:
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls'))}\n\n"
else:
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
break # 成功后退出循环
@@ -228,21 +255,23 @@ class OpenAIChatService:
break
async def create_image_chat_completion(
self,
request: ChatRequest,
self,
request: ChatRequest,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
image_generate_request = ImageGenerationRequest()
image_generate_request.prompt = request.messages[-1]["content"]
image_res = self.image_create_service.generate_images_chat(image_generate_request)
image_res = self.image_create_service.generate_images_chat(
image_generate_request
)
if request.stream:
return self._handle_stream_image_completion(request.model,image_res)
return self._handle_stream_image_completion(request.model, image_res)
else:
return self._handle_normal_image_completion(request.model,image_res)
return self._handle_normal_image_completion(request.model, image_res)
async def _handle_stream_image_completion(
self, model: str, image_data: str
self, model: str, image_data: str
) -> AsyncGenerator[str, None]:
if image_data:
openai_chunk = self.response_handler.handle_image_chat_response(
@@ -253,10 +282,12 @@ class OpenAIChatService:
text = self._extract_text_from_openai_chunk(openai_chunk)
if text:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in openai_optimizer.optimize_stream_output(
async for (
optimized_chunk
) in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(openai_chunk, t),
lambda c: f"data: {json.dumps(c)}\n\n"
lambda c: f"data: {json.dumps(c)}\n\n",
):
yield optimized_chunk
else:
@@ -265,11 +296,11 @@ class OpenAIChatService:
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Image chat streaming completed successfully")
def _handle_normal_image_completion(
self, model: str, image_data: str
self, model: str, image_data: str
) -> Dict[str, Any]:
return self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)

View File

@@ -4,6 +4,8 @@ from typing import Dict, Any, AsyncGenerator
import httpx
from abc import ABC, abstractmethod
from app.core.constants import DEFAULT_TIMEOUT
class ApiClient(ABC):
"""API客户端基类"""
@@ -20,7 +22,7 @@ class ApiClient(ABC):
class GeminiApiClient(ApiClient):
"""Gemini API客户端"""
def __init__(self, base_url: str, timeout: int = 300):
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
self.base_url = base_url
self.timeout = timeout

View File

@@ -1,9 +1,9 @@
from typing import Union, List
from typing import List, Union
import openai
from openai.types import CreateEmbeddingResponse
from app.core.logger import get_embeddings_logger
from app.log.logger import get_embeddings_logger
logger = get_embeddings_logger()

View File

@@ -1,14 +1,15 @@
import base64
import time
import uuid
from google import genai
from google.genai import types
import base64
from app.core.config import settings
from app.core.logger import get_image_create_logger
from app.core.uploader import ImageUploaderFactory
from app.schemas.openai_models import ImageGenerationRequest
from app.config.config import settings
from app.core.constants import VALID_IMAGE_RATIOS
from app.domain.openai_models import ImageGenerationRequest
from app.log.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory
logger = get_image_create_logger()
@@ -26,35 +27,34 @@ class ImageCreateService:
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
"""
import re
# 默认值
n = 1
aspect_ratio = self.aspect_ratio
# 解析n参数
n_match = re.search(r'{n:(\d+)}', prompt)
n_match = re.search(r"{n:(\d+)}", prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), '').strip()
# 解析ratio参数
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
prompt = prompt.replace(n_match.group(0), "").strip()
# 解析ratio参数
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
valid_ratios = ["1:1", "3:4", "4:3", "9:16", "16:9"]
if aspect_ratio not in valid_ratios:
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(valid_ratios)}"
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()
prompt = prompt.replace(ratio_match.group(0), "").strip()
return prompt, n, aspect_ratio
def generate_images(self, request: ImageGenerationRequest):
client = genai.Client(api_key=self.paid_key)
if request.size == "1024x1024":
self.aspect_ratio = "1:1"
elif request.size == "1792x1024":
@@ -67,13 +67,15 @@ class ImageCreateService:
)
# 解析prompt中的参数
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(request.prompt)
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
request.prompt
)
request.prompt = cleaned_prompt
# 如果prompt中指定了n则覆盖请求中的n
if prompt_n > 1:
request.n = prompt_n
# 如果prompt中指定了ratio则覆盖默认的aspect_ratio
if prompt_ratio != self.aspect_ratio:
self.aspect_ratio = prompt_ratio
@@ -96,46 +98,49 @@ class ImageCreateService:
for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes
image_uploader = None
if request.response_format == "b64_json":
base64_image = base64.b64encode(image_data).decode('utf-8')
images_data.append({
"b64_json": base64_image,
"revised_prompt": request.prompt
})
base64_image = base64.b64encode(image_data).decode("utf-8")
images_data.append(
{"b64_json": base64_image, "revised_prompt": request.prompt}
)
else:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.SMMS_SECRET_TOKEN
api_key=settings.SMMS_SECRET_TOKEN,
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY
api_key=settings.PICGO_API_KEY,
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
)
else:
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
raise ValueError(
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
)
upload_response = image_uploader.upload(image_data, filename)
images_data.append({
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt
})
images_data.append(
{
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt,
}
)
response_data = {
"created": int(time.time()), # Current timestamp
"data": images_data
"data": images_data,
}
return response_data
else:
@@ -147,9 +152,13 @@ class ImageCreateService:
if image_datas:
markdown_images = []
for index, image_data in enumerate(image_datas):
if 'url' in image_data:
markdown_images.append(f"![Generated Image {index+1}]({image_data['url']})")
if "url" in image_data:
markdown_images.append(
f"![Generated Image {index+1}]({image_data['url']})"
)
else:
# 如果是base64格式创建data URL
markdown_images.append(f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})")
markdown_images.append(
f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})"
)
return "\n".join(markdown_images)

View File

@@ -1,9 +1,9 @@
import asyncio
from itertools import cycle
from typing import Dict
from app.core.logger import get_key_manager_logger
from app.core.config import settings
from app.config.config import settings
from app.log.logger import get_key_manager_logger
logger = get_key_manager_logger()
@@ -20,7 +20,7 @@ class KeyManager:
async def get_paid_key(self) -> str:
return self.paid_key
async def get_next_key(self) -> str:
"""获取下一个API key"""
async with self.key_cycle_lock:
@@ -70,7 +70,7 @@ class KeyManager:
"""获取分类后的API key列表包括失败次数"""
valid_keys = {}
invalid_keys = {}
async with self.failure_count_lock:
for key in self.api_keys:
fail_count = self.key_failure_counts[key]
@@ -78,16 +78,21 @@ class KeyManager:
valid_keys[key] = fail_count
else:
invalid_keys[key] = fail_count
return {
"valid_keys": valid_keys,
"invalid_keys": invalid_keys
}
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
async def get_first_valid_key(self) -> str:
"""获取第一个有效的API key"""
async with self.failure_count_lock:
for key in self.key_failure_counts:
if self.key_failure_counts[key] < self.MAX_FAILURES:
return key
return self.api_keys[0]
_singleton_instance = None
_singleton_lock = asyncio.Lock()
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
"""
获取 KeyManager 单例实例

View File

@@ -1,16 +1,20 @@
import requests
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from app.core.logger import get_model_logger
from app.core.config import settings
from typing import Any, Dict, Optional
import requests
from app.config.config import settings
from app.log.logger import get_model_logger
logger = get_model_logger()
class ModelService:
def __init__(self, model_search: list, model_image: list):
self.model_search = model_search
self.model_image = model_image
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
def __init__(self, search_models: list, image_models: list):
self.search_models = search_models
self.image_models = image_models
self.base_url = settings.BASE_URL
self.filtered_models = settings.FILTERED_MODELS
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
url = f"{self.base_url}/models?key={api_key}"
@@ -19,6 +23,16 @@ class ModelService:
response = requests.get(url)
if response.status_code == 200:
gemini_models = response.json()
filtered_models_list = []
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]
if model_id not in self.filtered_models:
filtered_models_list.append(model)
else:
logger.info(f"Filtered out model: {model_id}")
gemini_models["models"] = filtered_models_list
return gemini_models
else:
logger.error(f"Error: {response.status_code}")
@@ -37,7 +51,7 @@ class ModelService:
return None
def convert_to_openai_models_format(
self, gemini_models: Dict[str, Any]
self, gemini_models: Dict[str, Any]
) -> Dict[str, Any]:
openai_format = {"object": "list", "data": [], "success": True}
@@ -54,11 +68,11 @@ class ModelService:
}
openai_format["data"].append(openai_model)
if model_id in self.model_search:
if model_id in self.search_models:
search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model)
if model_id in self.model_image:
if model_id in self.image_models:
image_model = openai_model.copy()
image_model["id"] = f"{model_id}-image"
openai_format["data"].append(image_model)
@@ -76,9 +90,9 @@ class ModelService:
model = model.strip()
if model.endswith("-search"):
model = model[:-7]
return model in settings.MODEL_SEARCH
return model in self.search_models
if model.endswith("-image"):
model = model[:-6]
return model in settings.MODEL_IMAGE
return model in self.image_models
return True
return model not in self.filtered_models

3
app/utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
工具包初始化模块
"""

146
app/utils/helpers.py Normal file
View File

@@ -0,0 +1,146 @@
"""
通用工具函数模块
"""
import json
import re
import base64
import requests
from typing import Dict, Any, List, Optional, Tuple
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
"""
从 base64 字符串中提取 MIME 类型和数据
Args:
base64_string: 可能包含 MIME 类型信息的 base64 字符串
Returns:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
Raises:
Exception: 如果获取图片失败
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
"""
格式化JSON响应
Args:
data: 要格式化的数据
indent: 缩进空格数
Returns:
str: 格式化后的JSON字符串
"""
return json.dumps(data, indent=indent, ensure_ascii=False)
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
"""
从prompt中解析参数
支持的格式:
- {n:数量} 例如: {n:2} 生成2张图片
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
Args:
prompt: 提示文本
default_ratio: 默认比例
Returns:
tuple: (清理后的提示文本, 图片数量, 比例)
"""
# 默认值
n = 1
aspect_ratio = default_ratio
# 解析n参数
n_match = re.search(r'{n:(\d+)}', prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), '').strip()
# 解析ratio参数
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()
return prompt, n, aspect_ratio
def extract_image_urls_from_markdown(text: str) -> List[str]:
"""
从Markdown文本中提取图片URL
Args:
text: Markdown文本
Returns:
List[str]: 图片URL列表
"""
pattern = IMAGE_URL_PATTERN
matches = re.findall(pattern, text)
return [match[1] for match in matches]
def is_valid_api_key(key: str) -> bool:
"""
检查API密钥格式是否有效
Args:
key: API密钥
Returns:
bool: 如果密钥格式有效则返回True
"""
# 检查Gemini API密钥格式
if key.startswith('AIza'):
return len(key) >= 30
# 检查OpenAI API密钥格式
if key.startswith('sk-'):
return len(key) >= 30
return False

View File

@@ -1,5 +1,5 @@
import requests
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
from enum import Enum
from typing import Optional, Any