mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 22:31:31 +08:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60dca70fcd | ||
|
|
89b9f7919a | ||
|
|
a8dc98ab6a | ||
|
|
b3a057b6ba | ||
|
|
b14bb93d8f | ||
|
|
8ca62707ea | ||
|
|
21444ed6c7 | ||
|
|
ba292dbedd | ||
|
|
6ba58ce9d1 | ||
|
|
16f16a3ae9 | ||
|
|
26dcb64687 | ||
|
|
df88492113 | ||
|
|
851bb9c09b | ||
|
|
0cac178572 | ||
|
|
67c85c994a | ||
|
|
ee979dd568 | ||
|
|
e79a1ba56c | ||
|
|
016e6e06ee | ||
|
|
8779a5f0b3 | ||
|
|
89f2825ac7 | ||
|
|
985a12554d | ||
|
|
37a7a140fc | ||
|
|
28e67cc3fa | ||
|
|
d99a0bde93 | ||
|
|
cb5cd92041 | ||
|
|
0be85e9536 | ||
|
|
632dee38b3 | ||
|
|
16c28bf1ba | ||
|
|
71af1db330 | ||
|
|
fb523f4a2e | ||
|
|
40e5ffa5f4 |
17
.env.example
17
.env.example
@@ -1,8 +1,11 @@
|
|||||||
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
|
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
|
||||||
ALLOWED_TOKENS=["sk-123456"]
|
ALLOWED_TOKENS=["sk-123456"]
|
||||||
# AUTH_TOKEN=sk-123456
|
# AUTH_TOKEN=sk-123456
|
||||||
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
TEST_MODEL=gemini-1.5-flash
|
||||||
TOOLS_CODE_EXECUTION_ENABLED=true
|
IMAGE_MODELS=["gemini-2.0-flash-exp"]
|
||||||
|
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||||
|
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
||||||
|
TOOLS_CODE_EXECUTION_ENABLED=false
|
||||||
SHOW_SEARCH_LINK=true
|
SHOW_SEARCH_LINK=true
|
||||||
SHOW_THINKING_PROCESS=true
|
SHOW_THINKING_PROCESS=true
|
||||||
BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||||
@@ -12,4 +15,14 @@ PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
|
|||||||
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
||||||
UPLOAD_PROVIDER=smms
|
UPLOAD_PROVIDER=smms
|
||||||
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||||
|
PICGO_API_KEY=xxxx
|
||||||
|
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
||||||
|
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
||||||
|
##########################################################################
|
||||||
|
#########################stream_optimizer 相关配置########################
|
||||||
|
STREAM_MIN_DELAY=0.016
|
||||||
|
STREAM_MAX_DELAY=0.024
|
||||||
|
STREAM_SHORT_TEXT_THRESHOLD=10
|
||||||
|
STREAM_LONG_TEXT_THRESHOLD=50
|
||||||
|
STREAM_CHUNK_SIZE=5
|
||||||
##########################################################################
|
##########################################################################
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ COPY ./app /app/app
|
|||||||
ENV API_KEYS='["your_api_key_1"]'
|
ENV API_KEYS='["your_api_key_1"]'
|
||||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||||
ENV TOOLS_CODE_EXECUTION_ENABLED=true
|
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||||
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
|
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
|
||||||
|
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|||||||
71
README.md
71
README.md
@@ -64,18 +64,31 @@
|
|||||||
AUTH_TOKEN="" # 超级管理员token,具有所有权限,默认使用 ALLOWED_TOKENS 的第一个
|
AUTH_TOKEN="" # 超级管理员token,具有所有权限,默认使用 ALLOWED_TOKENS 的第一个
|
||||||
|
|
||||||
# 模型功能配置
|
# 模型功能配置
|
||||||
MODEL_SEARCH=["gemini-2.0-flash-exp"] # 支持搜索功能的模型列表
|
TEST_MODEL="gemini-1.5-flash" # 用于测试密钥是否可用的模型名
|
||||||
|
SEARCH_MODELS=["gemini-2.0-flash-exp"] # 支持搜索功能的模型列表
|
||||||
|
IMAGE_MODELS=["gemini-2.0-flash-exp"] # 支持绘图功能的模型列表
|
||||||
TOOLS_CODE_EXECUTION_ENABLED=false # 是否启用代码执行工具,默认false
|
TOOLS_CODE_EXECUTION_ENABLED=false # 是否启用代码执行工具,默认false
|
||||||
SHOW_SEARCH_LINK=true # 是否在响应中显示搜索结果链接,默认true
|
SHOW_SEARCH_LINK=true # 是否在响应中显示搜索结果链接,默认true
|
||||||
SHOW_THINKING_PROCESS=true # 是否显示模型思考过程,默认true
|
SHOW_THINKING_PROCESS=true # 是否显示模型思考过程,默认true
|
||||||
|
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"] # 被禁用的模型列表
|
||||||
|
|
||||||
# 图片生成配置
|
# 图片生成配置
|
||||||
PAID_KEY="your-paid-api-key" # 付费版API Key,用于图片生成等高级功能
|
PAID_KEY="your-paid-api-key" # 付费版API Key,用于图片生成等高级功能
|
||||||
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型,默认使用imagen-3.0
|
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型,默认使用imagen-3.0
|
||||||
|
|
||||||
# 图片上传配置
|
# 图片上传配置
|
||||||
UPLOAD_PROVIDER="smms" # 图片上传提供商,目前支持smms
|
UPLOAD_PROVIDER="smms" # 图片上传提供商,目前支持smms、picgo、cloudflare_imgbed
|
||||||
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
|
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
|
||||||
|
PICGO_API_KEY="your-picogo-apikey" # PicoGo图床的API Key 可在 `https://www.picgo.net/settings/api` 获取
|
||||||
|
CLOUDFLARE_IMGBED_URL="https://xxxxxxx.pages.dev/upload" # CloudFlare 图床上传地址,可自行搭建:`https://github.com/MarSeventh/CloudFlare-ImgBed`
|
||||||
|
CLOUDFLARE_IMGBED_AUTH_CODE="your-cloudflare-imgber-auth-code" # CloudFlare图床的鉴权key,可在项目后台设置,若无鉴权则可直接置空。
|
||||||
|
|
||||||
|
# stream_optimizer 相关配置
|
||||||
|
STREAM_MIN_DELAY=0.016
|
||||||
|
STREAM_MAX_DELAY=0.024
|
||||||
|
STREAM_SHORT_TEXT_THRESHOLD=10
|
||||||
|
STREAM_LONG_TEXT_THRESHOLD=50
|
||||||
|
STREAM_CHUNK_SIZE=5
|
||||||
```
|
```
|
||||||
|
|
||||||
### 配置说明
|
### 配置说明
|
||||||
@@ -105,9 +118,17 @@
|
|||||||
|
|
||||||
#### 模型功能配置
|
#### 模型功能配置
|
||||||
|
|
||||||
- `MODEL_SEARCH`: 搜索功能支持的模型
|
- `TEST_MODEL`: 用于测试密钥可用性的模型
|
||||||
|
- 默认值: `"gemini-1.5-flash"`
|
||||||
|
- `SEARCH_MODELS`: 搜索功能支持的模型
|
||||||
- 默认值: `["gemini-2.0-flash-exp"]`
|
- 默认值: `["gemini-2.0-flash-exp"]`
|
||||||
- 说明: 仅列表中的模型可使用搜索功能
|
- 说明: 仅列表中的模型可使用搜索功能
|
||||||
|
- `IMAGE_MODELS`: 绘图功能支持的模型
|
||||||
|
- 默认值: `["gemini-2.0-flash-exp"]`
|
||||||
|
- 说明: 仅列表中的模型可使用绘图功能
|
||||||
|
- `FILTERED_MODELS`: 被禁用的模型列表
|
||||||
|
- 默认值: `["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]`
|
||||||
|
- 说明: 列表中的模型将被禁用
|
||||||
- `TOOLS_CODE_EXECUTION_ENABLED`: 代码执行功能
|
- `TOOLS_CODE_EXECUTION_ENABLED`: 代码执行功能
|
||||||
- 默认值: `false`
|
- 默认值: `false`
|
||||||
- 安全提示: 生产环境建议禁用
|
- 安全提示: 生产环境建议禁用
|
||||||
@@ -131,10 +152,44 @@
|
|||||||
|
|
||||||
- `UPLOAD_PROVIDER`: 图片上传服务提供商
|
- `UPLOAD_PROVIDER`: 图片上传服务提供商
|
||||||
- 默认值: `smms`
|
- 默认值: `smms`
|
||||||
- 说明: 目前支持 SM.MS 图床
|
- 可选值: `smms`, `picgo`, `cloudflare_imgbed`
|
||||||
|
- 说明: 用于选择图片上传的服务提供商。目前支持 SM.MS 图床, PicGo 图床, 以及 Cloudflare ImgBed。
|
||||||
|
|
||||||
- `SMMS_SECRET_TOKEN`: SM.MS API Token
|
- `SMMS_SECRET_TOKEN`: SM.MS API Token
|
||||||
- 用途: 用于图片上传到 SM.MS 图床
|
- 用途: 用于图片上传到 SM.MS 图床的身份验证。
|
||||||
- 获取方式: 需要在 SM.MS 官网注册并获取
|
- 获取方式: 需要在 [SM.MS 官网](https://sm.ms/) 注册并获取。
|
||||||
|
|
||||||
|
- `PICGO_API_KEY`: PicGo API Key
|
||||||
|
- 用途: 用于图片上传到 PicGo 图床的身份验证。
|
||||||
|
- 获取方式: 可在 [PicGo 官网](https://www.picgo.net/settings/api) 的设置页面 API 选项中获取。
|
||||||
|
|
||||||
|
- `CLOUDFLARE_IMGBED_URL`: Cloudflare ImgBed 上传地址
|
||||||
|
- 用途: 指定 Cloudflare ImgBed 图床的上传 API 地址。
|
||||||
|
- 获取方式: 如果您自行搭建了 Cloudflare ImgBed 服务,请填写您的服务部署地址。参考 [Cloudflare-ImgBed 项目](https://github.com/MarSeventh/CloudFlare-ImgBed) 自行搭建。
|
||||||
|
- 注意: URL 必须以 `https://` 开头,并指向 `/upload` 路径 ,例如 `https://cloudflare-imgbed-7b0.pages.dev/upload`。
|
||||||
|
|
||||||
|
- `CLOUDFLARE_IMGBED_AUTH_CODE`: Cloudflare ImgBed 鉴权 Key
|
||||||
|
- 用途: 用于 Cloudflare ImgBed 图床的身份验证。
|
||||||
|
- 说明: 如果您的 Cloudflare ImgBed 服务启用了鉴权,请填写鉴权 Key。若未启用鉴权,则留空即可。
|
||||||
|
- 获取方式: 在 Cloudflare ImgBed 项目的后台设置中获取,或在搭建时自行设置。
|
||||||
|
|
||||||
|
#### 流式输出优化配置
|
||||||
|
|
||||||
|
- `STREAM_MIN_DELAY`: 最小延迟时间
|
||||||
|
- 默认值: `0.016`(秒)
|
||||||
|
- 说明: 长文本输出时使用的最小延迟时间,值越小输出速度越快
|
||||||
|
- `STREAM_MAX_DELAY`: 最大延迟时间
|
||||||
|
- 默认值: `0.024`(秒)
|
||||||
|
- 说明: 短文本输出时使用的最大延迟时间,值越大输出速度越慢
|
||||||
|
- `STREAM_SHORT_TEXT_THRESHOLD`: 短文本阈值
|
||||||
|
- 默认值: `10`(字符)
|
||||||
|
- 说明: 小于此长度的文本被视为短文本,将使用最大延迟输出
|
||||||
|
- `STREAM_LONG_TEXT_THRESHOLD`: 长文本阈值
|
||||||
|
- 默认值: `50`(字符)
|
||||||
|
- 说明: 大于此长度的文本被视为长文本,将使用最小延迟并分块输出
|
||||||
|
- `STREAM_CHUNK_SIZE`: 长文本分块大小
|
||||||
|
- 默认值: `5`(字符)
|
||||||
|
- 说明: 长文本分块输出时,每个块的大小
|
||||||
|
|
||||||
### ▶️ 运行
|
### ▶️ 运行
|
||||||
|
|
||||||
@@ -208,7 +263,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
"content": "你好"
|
"content": "你好"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"model": "gemini-1.5-flash-002",
|
"model": "gemini-1.5-flash",
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"stream": false,
|
"stream": false,
|
||||||
"tools": [],
|
"tools": [],
|
||||||
@@ -221,7 +276,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
|
|
||||||
- `messages`: 消息列表,格式与 OpenAI API 相同
|
- `messages`: 消息列表,格式与 OpenAI API 相同
|
||||||
- `model`: 模型名称,支持所有Gemini模型,包括:
|
- `model`: 模型名称,支持所有Gemini模型,包括:
|
||||||
- `gemini-1.5-flash-002`: 快速响应模型
|
- `gemini-1.5-flash`: 快速响应模型
|
||||||
- `gemini-2.0-flash-exp`: 实验性快速响应模型
|
- `gemini-2.0-flash-exp`: 实验性快速响应模型
|
||||||
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
|
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
|
||||||
- `stream`: 是否开启流式响应,`true` 或 `false`
|
- `stream`: 是否开启流式响应,`true` 或 `false`
|
||||||
|
|||||||
@@ -1,131 +0,0 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.logger import get_gemini_logger
|
|
||||||
from app.core.security import SecurityService
|
|
||||||
from app.schemas.gemini_models import GeminiContent, GeminiRequest
|
|
||||||
from app.services.gemini_chat_service import GeminiChatService
|
|
||||||
from app.services.key_manager import KeyManager, get_key_manager_instance
|
|
||||||
from app.services.model_service import ModelService
|
|
||||||
from app.services.chat.retry_handler import RetryHandler
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/gemini/v1beta")
|
|
||||||
router_v1beta = APIRouter(prefix="/v1beta")
|
|
||||||
logger = get_gemini_logger()
|
|
||||||
|
|
||||||
# 初始化服务
|
|
||||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
|
||||||
|
|
||||||
async def get_key_manager():
|
|
||||||
return await get_key_manager_instance()
|
|
||||||
|
|
||||||
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
|
|
||||||
return await key_manager.get_next_working_key()
|
|
||||||
|
|
||||||
model_service = ModelService(settings.MODEL_SEARCH)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/models")
|
|
||||||
@router_v1beta.get("/models")
|
|
||||||
async def list_models(_=Depends(security_service.verify_key),
|
|
||||||
key_manager: KeyManager = Depends(get_key_manager)):
|
|
||||||
"""获取可用的Gemini模型列表"""
|
|
||||||
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
|
||||||
logger.info("Handling Gemini models list request")
|
|
||||||
api_key = await key_manager.get_next_working_key()
|
|
||||||
logger.info(f"Using API key: {api_key}")
|
|
||||||
models_json = model_service.get_gemini_models(api_key)
|
|
||||||
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
|
|
||||||
"displayName": "Gemini 2.0 Flash Search Experimental",
|
|
||||||
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
|
|
||||||
"outputTokenLimit": 8192,
|
|
||||||
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
|
|
||||||
"topP": 0.95, "topK": 64, "maxTemperature": 2})
|
|
||||||
return models_json
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/models/{model_name}:generateContent")
|
|
||||||
@router_v1beta.post("/models/{model_name}:generateContent")
|
|
||||||
@RetryHandler(max_retries=3, key_arg="api_key")
|
|
||||||
async def generate_content(
|
|
||||||
model_name: str,
|
|
||||||
request: GeminiRequest,
|
|
||||||
_=Depends(security_service.verify_goog_api_key),
|
|
||||||
api_key: str = Depends(get_next_working_key_wrapper),
|
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
|
||||||
):
|
|
||||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
|
||||||
"""非流式生成内容"""
|
|
||||||
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
|
||||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
|
||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
|
||||||
logger.info(f"Using API key: {api_key}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await chat_service.generate_content(
|
|
||||||
model=model_name,
|
|
||||||
request=request,
|
|
||||||
api_key=api_key
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/models/{model_name}:streamGenerateContent")
|
|
||||||
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
|
||||||
@RetryHandler(max_retries=3, key_arg="api_key")
|
|
||||||
async def stream_generate_content(
|
|
||||||
model_name: str,
|
|
||||||
request: GeminiRequest,
|
|
||||||
_=Depends(security_service.verify_goog_api_key),
|
|
||||||
api_key: str = Depends(get_next_working_key_wrapper),
|
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
|
||||||
):
|
|
||||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
|
||||||
"""流式生成内容"""
|
|
||||||
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
|
||||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
|
||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
|
||||||
logger.info(f"Using API key: {api_key}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response_stream = chat_service.stream_generate_content(
|
|
||||||
model=model_name,
|
|
||||||
request=request,
|
|
||||||
api_key=api_key
|
|
||||||
)
|
|
||||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Streaming request failed: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/verify-key/{api_key}")
|
|
||||||
async def verify_key(api_key: str):
|
|
||||||
key_manager = await get_key_manager()
|
|
||||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
|
||||||
"""验证Gemini API密钥的有效性"""
|
|
||||||
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
|
||||||
logger.info("Verifying API key validity")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 使用generate_content接口测试key的有效性
|
|
||||||
gemini_requset = GeminiRequest(
|
|
||||||
contents=[
|
|
||||||
GeminiContent(
|
|
||||||
role="user",
|
|
||||||
parts=[{"text": "hi"}]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
|
|
||||||
if response:
|
|
||||||
return JSONResponse({"status": "valid"})
|
|
||||||
return JSONResponse({"status": "invalid"})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Key verification failed: {str(e)}")
|
|
||||||
return JSONResponse({"status": "invalid", "error": str(e)})
|
|
||||||
55
app/config/config.py
Normal file
55
app/config/config.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""
|
||||||
|
应用程序配置模块
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""应用程序配置"""
|
||||||
|
# API相关配置
|
||||||
|
API_KEYS: List[str]
|
||||||
|
ALLOWED_TOKENS: List[str]
|
||||||
|
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
||||||
|
AUTH_TOKEN: str = ""
|
||||||
|
MAX_FAILURES: int = 3
|
||||||
|
TEST_MODEL: str = DEFAULT_MODEL
|
||||||
|
|
||||||
|
# 模型相关配置
|
||||||
|
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||||
|
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||||
|
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
||||||
|
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||||
|
SHOW_SEARCH_LINK: bool = True
|
||||||
|
SHOW_THINKING_PROCESS: bool = True
|
||||||
|
|
||||||
|
# 图像生成相关配置
|
||||||
|
PAID_KEY: str = ""
|
||||||
|
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
|
||||||
|
UPLOAD_PROVIDER: str = "smms"
|
||||||
|
SMMS_SECRET_TOKEN: str = ""
|
||||||
|
PICGO_API_KEY: str = ""
|
||||||
|
CLOUDFLARE_IMGBED_URL: str = ""
|
||||||
|
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||||
|
|
||||||
|
# 流式输出优化器配置
|
||||||
|
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
|
||||||
|
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
|
||||||
|
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
||||||
|
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
|
||||||
|
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# 设置默认AUTH_TOKEN(如果未提供)
|
||||||
|
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
|
||||||
|
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = Settings()
|
||||||
71
app/core/application.py
Normal file
71
app/core/application.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
应用程序工厂模块,负责创建和配置FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from app.config.config import settings
|
||||||
|
from app.log.logger import get_application_logger
|
||||||
|
from app.middleware.middleware import setup_middlewares
|
||||||
|
from app.exception.exceptions import setup_exception_handlers
|
||||||
|
from app.router.routes import setup_routers
|
||||||
|
from app.service.key.key_manager import get_key_manager_instance
|
||||||
|
from app.core.initialization import initialize_app
|
||||||
|
|
||||||
|
logger = get_application_logger()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""
|
||||||
|
应用程序生命周期管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用实例
|
||||||
|
"""
|
||||||
|
# 启动事件
|
||||||
|
logger.info("Application starting up...")
|
||||||
|
try:
|
||||||
|
# 初始化KeyManager
|
||||||
|
await get_key_manager_instance(settings.API_KEYS)
|
||||||
|
logger.info("KeyManager initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize KeyManager: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
yield # 应用程序运行期间
|
||||||
|
|
||||||
|
# 关闭事件
|
||||||
|
logger.info("Application shutting down...")
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""
|
||||||
|
创建并配置FastAPI应用程序实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FastAPI: 配置好的FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
# 初始化应用程序
|
||||||
|
initialize_app()
|
||||||
|
|
||||||
|
# 创建FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title="Gemini Balance API",
|
||||||
|
description="Gemini API代理服务,支持负载均衡和密钥管理",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置静态文件
|
||||||
|
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||||
|
|
||||||
|
# 配置中间件
|
||||||
|
setup_middlewares(app)
|
||||||
|
|
||||||
|
# 配置异常处理器
|
||||||
|
setup_exception_handlers(app)
|
||||||
|
|
||||||
|
# 配置路由
|
||||||
|
setup_routers(app)
|
||||||
|
|
||||||
|
return app
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
from pydantic_settings import BaseSettings
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
API_KEYS: List[str]
|
|
||||||
ALLOWED_TOKENS: List[str]
|
|
||||||
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
|
|
||||||
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
|
|
||||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
|
||||||
SHOW_SEARCH_LINK: bool = True
|
|
||||||
SHOW_THINKING_PROCESS: bool = True
|
|
||||||
AUTH_TOKEN: str = ""
|
|
||||||
MAX_FAILURES: int = 3
|
|
||||||
PAID_KEY: str = ""
|
|
||||||
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
|
|
||||||
UPLOAD_PROVIDER: str = "smms"
|
|
||||||
SMMS_SECRET_TOKEN: str = ""
|
|
||||||
TEST_MODEL: str = "gemini-1.5-flash"
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
if not self.AUTH_TOKEN:
|
|
||||||
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else ""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_file = ".env"
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
41
app/core/constants.py
Normal file
41
app/core/constants.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
常量定义模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
# API相关常量
|
||||||
|
API_VERSION = "v1beta"
|
||||||
|
DEFAULT_TIMEOUT = 300 # 秒
|
||||||
|
|
||||||
|
# 模型相关常量
|
||||||
|
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||||
|
DEFAULT_MODEL = "gemini-1.5-flash"
|
||||||
|
DEFAULT_TEMPERATURE = 0.7
|
||||||
|
DEFAULT_MAX_TOKENS = 8192
|
||||||
|
DEFAULT_TOP_P = 0.9
|
||||||
|
DEFAULT_TOP_K = 40
|
||||||
|
DEFAULT_FILTER_MODELS = [
|
||||||
|
"gemini-1.0-pro-vision-latest",
|
||||||
|
"gemini-pro-vision",
|
||||||
|
"chat-bison-001",
|
||||||
|
"text-bison-001",
|
||||||
|
"embedding-gecko-001"
|
||||||
|
]
|
||||||
|
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||||
|
|
||||||
|
# 图像生成相关常量
|
||||||
|
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
||||||
|
|
||||||
|
# 上传提供商
|
||||||
|
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
|
||||||
|
DEFAULT_UPLOAD_PROVIDER = "smms"
|
||||||
|
|
||||||
|
# 流式输出相关常量
|
||||||
|
DEFAULT_STREAM_MIN_DELAY = 0.016
|
||||||
|
DEFAULT_STREAM_MAX_DELAY = 0.024
|
||||||
|
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
|
||||||
|
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
||||||
|
DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||||
|
|
||||||
|
# 正则表达式模式
|
||||||
|
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||||
|
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||||
40
app/core/initialization.py
Normal file
40
app/core/initialization.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
应用程序初始化模块
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from app.log.logger import get_initialization_logger
|
||||||
|
|
||||||
|
logger = get_initialization_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_directories_exist(directories: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
确保指定的目录存在,如果不存在则创建
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directories: 要确保存在的目录列表
|
||||||
|
"""
|
||||||
|
for directory in directories:
|
||||||
|
try:
|
||||||
|
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.info(f"Ensured directory exists: {directory}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create directory {directory}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_app() -> None:
|
||||||
|
"""
|
||||||
|
初始化应用程序,确保所需的目录和文件都存在
|
||||||
|
"""
|
||||||
|
# 确保必要的目录存在
|
||||||
|
required_directories = [
|
||||||
|
"app/static/css",
|
||||||
|
"app/static/js",
|
||||||
|
"app/static/icons",
|
||||||
|
"app/templates",
|
||||||
|
]
|
||||||
|
|
||||||
|
ensure_directories_exist(required_directories)
|
||||||
|
logger.info("Application initialization completed")
|
||||||
@@ -1,13 +1,17 @@
|
|||||||
from fastapi import HTTPException, Header
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from app.core.logger import get_security_logger
|
|
||||||
from app.core.config import settings
|
from fastapi import Header, HTTPException
|
||||||
|
|
||||||
|
from app.config.config import settings
|
||||||
|
from app.log.logger import get_security_logger
|
||||||
|
|
||||||
logger = get_security_logger()
|
logger = get_security_logger()
|
||||||
|
|
||||||
|
|
||||||
def verify_auth_token(token: str) -> bool:
|
def verify_auth_token(token: str) -> bool:
|
||||||
return token == settings.AUTH_TOKEN
|
return token == settings.AUTH_TOKEN
|
||||||
|
|
||||||
|
|
||||||
class SecurityService:
|
class SecurityService:
|
||||||
def __init__(self, allowed_tokens: list, auth_token: str):
|
def __init__(self, allowed_tokens: list, auth_token: str):
|
||||||
self.allowed_tokens = allowed_tokens
|
self.allowed_tokens = allowed_tokens
|
||||||
@@ -20,7 +24,7 @@ class SecurityService:
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
async def verify_authorization(
|
async def verify_authorization(
|
||||||
self, authorization: Optional[str] = Header(None)
|
self, authorization: Optional[str] = Header(None)
|
||||||
) -> str:
|
) -> str:
|
||||||
if not authorization:
|
if not authorization:
|
||||||
logger.error("Missing Authorization header")
|
logger.error("Missing Authorization header")
|
||||||
@@ -39,19 +43,26 @@ class SecurityService:
|
|||||||
|
|
||||||
return token
|
return token
|
||||||
|
|
||||||
async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str:
|
async def verify_goog_api_key(
|
||||||
|
self, x_goog_api_key: Optional[str] = Header(None)
|
||||||
|
) -> str:
|
||||||
"""验证Google API Key"""
|
"""验证Google API Key"""
|
||||||
if not x_goog_api_key:
|
if not x_goog_api_key:
|
||||||
logger.error("Missing x-goog-api-key header")
|
logger.error("Missing x-goog-api-key header")
|
||||||
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
|
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
|
||||||
|
|
||||||
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
|
if (
|
||||||
|
x_goog_api_key not in self.allowed_tokens
|
||||||
|
and x_goog_api_key != self.auth_token
|
||||||
|
):
|
||||||
logger.error("Invalid x-goog-api-key")
|
logger.error("Invalid x-goog-api-key")
|
||||||
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
|
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
|
||||||
|
|
||||||
return x_goog_api_key
|
return x_goog_api_key
|
||||||
|
|
||||||
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
|
async def verify_auth_token(
|
||||||
|
self, authorization: Optional[str] = Header(None)
|
||||||
|
) -> str:
|
||||||
if not authorization:
|
if not authorization:
|
||||||
logger.error("Missing auth_token header")
|
logger.error("Missing auth_token header")
|
||||||
raise HTTPException(status_code=401, detail="Missing auth_token header")
|
raise HTTPException(status_code=401, detail="Missing auth_token header")
|
||||||
|
|||||||
@@ -1,163 +0,0 @@
|
|||||||
import requests
|
|
||||||
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional, Any
|
|
||||||
|
|
||||||
class UploadErrorType(Enum):
|
|
||||||
"""上传错误类型枚举"""
|
|
||||||
NETWORK_ERROR = "network_error" # 网络请求错误
|
|
||||||
AUTH_ERROR = "auth_error" # 认证错误
|
|
||||||
INVALID_FILE = "invalid_file" # 无效文件
|
|
||||||
SERVER_ERROR = "server_error" # 服务器错误
|
|
||||||
PARSE_ERROR = "parse_error" # 响应解析错误
|
|
||||||
UNKNOWN = "unknown" # 未知错误
|
|
||||||
|
|
||||||
|
|
||||||
class UploadError(Exception):
|
|
||||||
"""图片上传错误异常类"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message: str,
|
|
||||||
error_type: UploadErrorType = UploadErrorType.UNKNOWN,
|
|
||||||
status_code: Optional[int] = None,
|
|
||||||
details: Optional[dict] = None,
|
|
||||||
original_error: Optional[Exception] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
初始化上传错误异常
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 错误消息
|
|
||||||
error_type: 错误类型
|
|
||||||
status_code: HTTP状态码
|
|
||||||
details: 详细错误信息
|
|
||||||
original_error: 原始异常
|
|
||||||
"""
|
|
||||||
self.message = message
|
|
||||||
self.error_type = error_type
|
|
||||||
self.status_code = status_code
|
|
||||||
self.details = details or {}
|
|
||||||
self.original_error = original_error
|
|
||||||
|
|
||||||
# 构建完整错误信息
|
|
||||||
full_message = f"[{error_type.value}] {message}"
|
|
||||||
if status_code:
|
|
||||||
full_message = f"{full_message} (Status: {status_code})"
|
|
||||||
if details:
|
|
||||||
full_message = f"{full_message} - Details: {details}"
|
|
||||||
|
|
||||||
super().__init__(full_message)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_response(cls, response: Any, message: Optional[str] = None) -> "UploadError":
|
|
||||||
"""
|
|
||||||
从HTTP响应创建错误实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: HTTP响应对象
|
|
||||||
message: 自定义错误消息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
error_data = response.json()
|
|
||||||
details = error_data.get("data", {})
|
|
||||||
return cls(
|
|
||||||
message=message or error_data.get("message", "Unknown error"),
|
|
||||||
error_type=UploadErrorType.SERVER_ERROR,
|
|
||||||
status_code=response.status_code,
|
|
||||||
details=details
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
return cls(
|
|
||||||
message=message or "Failed to parse error response",
|
|
||||||
error_type=UploadErrorType.PARSE_ERROR,
|
|
||||||
status_code=response.status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SmMsUploader(ImageUploader):
|
|
||||||
API_URL = "https://sm.ms/api/v2/upload"
|
|
||||||
|
|
||||||
def __init__(self, api_key: str):
|
|
||||||
self.api_key = api_key
|
|
||||||
|
|
||||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
|
||||||
try:
|
|
||||||
# 准备请求头
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Basic {self.api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 准备文件数据
|
|
||||||
files = {
|
|
||||||
"smfile": (filename, file, "image/png")
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求
|
|
||||||
response = requests.post(
|
|
||||||
self.API_URL,
|
|
||||||
headers=headers,
|
|
||||||
files=files
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查响应状态
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
# 解析响应
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
# 验证上传是否成功
|
|
||||||
if not result.get("success"):
|
|
||||||
raise UploadError(result.get("message", "Upload failed"))
|
|
||||||
|
|
||||||
# 转换为统一格式
|
|
||||||
data = result["data"]
|
|
||||||
image_metadata = ImageMetadata(
|
|
||||||
width=data["width"],
|
|
||||||
height=data["height"],
|
|
||||||
filename=data["filename"],
|
|
||||||
size=data["size"],
|
|
||||||
url=data["url"],
|
|
||||||
delete_url=data["delete"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return UploadResponse(
|
|
||||||
success=True,
|
|
||||||
code="success",
|
|
||||||
message="Upload success",
|
|
||||||
data=image_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
except requests.RequestException as e:
|
|
||||||
# 处理网络请求相关错误
|
|
||||||
raise UploadError(f"Upload request failed: {str(e)}")
|
|
||||||
except (KeyError, ValueError) as e:
|
|
||||||
# 处理响应解析错误
|
|
||||||
raise UploadError(f"Invalid response format: {str(e)}")
|
|
||||||
except Exception as e:
|
|
||||||
# 处理其他未预期的错误
|
|
||||||
raise UploadError(f"Upload failed: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
class QiniuUploader(ImageUploader):
|
|
||||||
def __init__(self, access_key: str, secret_key: str):
|
|
||||||
self.access_key = access_key
|
|
||||||
self.secret_key = secret_key
|
|
||||||
|
|
||||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
|
||||||
# 实现七牛云的具体上传逻辑
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImageUploaderFactory:
|
|
||||||
@staticmethod
|
|
||||||
def create(provider: str, **credentials) -> ImageUploader:
|
|
||||||
if provider == "smms":
|
|
||||||
return SmMsUploader(credentials["api_key"])
|
|
||||||
elif provider == "qiniu":
|
|
||||||
return QiniuUploader(
|
|
||||||
credentials["access_key"],
|
|
||||||
credentials["secret_key"]
|
|
||||||
)
|
|
||||||
raise ValueError(f"Unknown provider: {provider}")
|
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ class GeminiContent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class GeminiRequest(BaseModel):
|
class GeminiRequest(BaseModel):
|
||||||
contents: List[GeminiContent]
|
contents: List[GeminiContent] = []
|
||||||
tools: Optional[List[Dict[str, Any]]] = []
|
tools: Optional[List[Dict[str, Any]]] = []
|
||||||
safetySettings: Optional[List[SafetySetting]] = None
|
safetySettings: Optional[List[SafetySetting]] = None
|
||||||
generationConfig: Optional[GenerationConfig] = None
|
generationConfig: Optional[GenerationConfig] = None
|
||||||
@@ -1,17 +1,19 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from app.core.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
messages: List[dict]
|
messages: List[dict]
|
||||||
model: str = "gemini-1.5-flash-002"
|
model: str = DEFAULT_MODEL
|
||||||
temperature: Optional[float] = 0.7
|
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
tools: Optional[List[dict]] = []
|
tools: Optional[List[dict]] = []
|
||||||
max_tokens: Optional[int] = 8192
|
max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
|
||||||
|
top_p: Optional[float] = DEFAULT_TOP_P
|
||||||
|
top_k: Optional[int] = DEFAULT_TOP_K
|
||||||
stop: Optional[List[str]] = []
|
stop: Optional[List[str]] = []
|
||||||
top_p: Optional[float] = 0.9
|
|
||||||
top_k: Optional[int] = 40
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingRequest(BaseModel):
|
class EmbeddingRequest(BaseModel):
|
||||||
140
app/exception/exceptions.py
Normal file
140
app/exception/exceptions.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""
|
||||||
|
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
|
|
||||||
|
from app.log.logger import get_exceptions_logger
|
||||||
|
|
||||||
|
logger = get_exceptions_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class APIError(Exception):
|
||||||
|
"""API错误基类"""
|
||||||
|
|
||||||
|
def __init__(self, status_code: int, detail: str, error_code: str = None):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.detail = detail
|
||||||
|
self.error_code = error_code or "api_error"
|
||||||
|
super().__init__(self.detail)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(APIError):
|
||||||
|
"""认证错误"""
|
||||||
|
|
||||||
|
def __init__(self, detail: str = "Authentication failed"):
|
||||||
|
super().__init__(
|
||||||
|
status_code=401, detail=detail, error_code="authentication_error"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationError(APIError):
|
||||||
|
"""授权错误"""
|
||||||
|
|
||||||
|
def __init__(self, detail: str = "Not authorized to access this resource"):
|
||||||
|
super().__init__(
|
||||||
|
status_code=403, detail=detail, error_code="authorization_error"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceNotFoundError(APIError):
|
||||||
|
"""资源未找到错误"""
|
||||||
|
|
||||||
|
def __init__(self, detail: str = "Resource not found"):
|
||||||
|
super().__init__(
|
||||||
|
status_code=404, detail=detail, error_code="resource_not_found"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotSupportedError(APIError):
|
||||||
|
"""模型不支持错误"""
|
||||||
|
|
||||||
|
def __init__(self, model: str):
|
||||||
|
super().__init__(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Model {model} is not supported",
|
||||||
|
error_code="model_not_supported",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyError(APIError):
|
||||||
|
"""API密钥错误"""
|
||||||
|
|
||||||
|
def __init__(self, detail: str = "Invalid or expired API key"):
|
||||||
|
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceUnavailableError(APIError):
|
||||||
|
"""服务不可用错误"""
|
||||||
|
|
||||||
|
def __init__(self, detail: str = "Service temporarily unavailable"):
|
||||||
|
super().__init__(
|
||||||
|
status_code=503, detail=detail, error_code="service_unavailable"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_exception_handlers(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置应用程序的异常处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@app.exception_handler(APIError)
|
||||||
|
async def api_error_handler(request: Request, exc: APIError):
|
||||||
|
"""处理API错误"""
|
||||||
|
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={"error": {"code": exc.error_code, "message": exc.detail}},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(StarletteHTTPException)
|
||||||
|
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||||
|
"""处理HTTP异常"""
|
||||||
|
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={"error": {"code": "http_error", "message": exc.detail}},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(
|
||||||
|
request: Request, exc: RequestValidationError
|
||||||
|
):
|
||||||
|
"""处理请求验证错误"""
|
||||||
|
error_details = []
|
||||||
|
for error in exc.errors():
|
||||||
|
error_details.append(
|
||||||
|
{"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.error(f"Validation Error: {error_details}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "validation_error",
|
||||||
|
"message": "Request validation failed",
|
||||||
|
"details": error_details,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request: Request, exc: Exception):
|
||||||
|
"""处理通用异常"""
|
||||||
|
logger.exception(f"Unhandled Exception: {str(exc)}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "internal_server_error",
|
||||||
|
"message": "An unexpected error occurred",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
171
app/handler/message_converter.py
Normal file
171
app/handler/message_converter.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
# app/services/chat/message_converter.py
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
|
||||||
|
|
||||||
|
|
||||||
|
class MessageConverter(ABC):
|
||||||
|
"""消息转换器基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_mime_type_and_data(base64_string):
|
||||||
|
"""
|
||||||
|
从 base64 字符串中提取 MIME 类型和数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple: (mime_type, encoded_data)
|
||||||
|
"""
|
||||||
|
# 检查字符串是否以 "data:" 格式开始
|
||||||
|
if base64_string.startswith('data:'):
|
||||||
|
# 提取 MIME 类型和数据
|
||||||
|
pattern = DATA_URL_PATTERN
|
||||||
|
match = re.match(pattern, base64_string)
|
||||||
|
if match:
|
||||||
|
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||||
|
encoded_data = match.group(2)
|
||||||
|
return mime_type, encoded_data
|
||||||
|
|
||||||
|
# 如果不是预期格式,假定它只是数据部分
|
||||||
|
return None, base64_string
|
||||||
|
|
||||||
|
def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||||
|
if image_url.startswith("data:image"):
|
||||||
|
mime_type, encoded_data = _get_mime_type_and_data(image_url)
|
||||||
|
return {
|
||||||
|
"inline_data": {
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"data": encoded_data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_image_to_base64(url: str) -> str:
|
||||||
|
"""
|
||||||
|
将图片URL转换为base64编码
|
||||||
|
Args:
|
||||||
|
url: 图片URL
|
||||||
|
Returns:
|
||||||
|
str: base64编码的图片数据
|
||||||
|
"""
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# 将图片内容转换为base64
|
||||||
|
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||||
|
return img_data
|
||||||
|
else:
|
||||||
|
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
处理可能包含图片URL的文本,提取图片并转换为base64
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 可能包含图片URL的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: 包含文本和图片的部分列表
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
img_url_match = re.search(IMAGE_URL_PATTERN, text)
|
||||||
|
if img_url_match:
|
||||||
|
# 提取URL
|
||||||
|
img_url = img_url_match.group(2)
|
||||||
|
# 将URL对应的图片转换为base64
|
||||||
|
try:
|
||||||
|
base64_data = _convert_image_to_base64(img_url)
|
||||||
|
parts.append({
|
||||||
|
"inlineData": {
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": base64_data
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
# 如果转换失败,回退到文本模式
|
||||||
|
parts.append({"text": text})
|
||||||
|
else:
|
||||||
|
# 没有图片URL,作为纯文本处理
|
||||||
|
parts.append({"text": text})
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIMessageConverter(MessageConverter):
|
||||||
|
"""OpenAI消息格式转换器"""
|
||||||
|
|
||||||
|
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||||
|
converted_messages = []
|
||||||
|
system_instruction_parts = []
|
||||||
|
|
||||||
|
for idx, msg in enumerate(messages):
|
||||||
|
role = msg.get("role", "")
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
# 特别处理最后一个assistant的消息,按\n\n分割
|
||||||
|
if "content" in msg and isinstance(msg["content"], str) and msg["content"] and role == "assistant" and idx == len(messages) - 2:
|
||||||
|
# 按\n\n分割消息
|
||||||
|
content_parts = msg["content"].split("\n\n")
|
||||||
|
for part in content_parts:
|
||||||
|
if not part.strip(): # 跳过空内容
|
||||||
|
continue
|
||||||
|
# 处理可能包含图片的文本
|
||||||
|
parts.extend(_process_text_with_image(part))
|
||||||
|
elif "content" in msg and isinstance(msg["content"], str) and msg["content"]:
|
||||||
|
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
||||||
|
parts.extend(_process_text_with_image(msg["content"]))
|
||||||
|
elif "content" in msg and isinstance(msg["content"], list):
|
||||||
|
for content in msg["content"]:
|
||||||
|
if isinstance(content, str) and content:
|
||||||
|
parts.append({"text": content})
|
||||||
|
elif isinstance(content, dict):
|
||||||
|
if content["type"] == "text" and content["text"]:
|
||||||
|
parts.append({"text": content["text"]})
|
||||||
|
elif content["type"] == "image_url":
|
||||||
|
parts.append(_convert_image(content["image_url"]["url"]))
|
||||||
|
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||||
|
for tool_call in msg["tool_calls"]:
|
||||||
|
function_call = tool_call.get("function",{})
|
||||||
|
function_call["args"] = json.loads(function_call.get("arguments","{}"))
|
||||||
|
del function_call["arguments"]
|
||||||
|
parts.append({"functionCall": function_call})
|
||||||
|
|
||||||
|
if role not in SUPPORTED_ROLES:
|
||||||
|
if role == "tool":
|
||||||
|
role = "user"
|
||||||
|
else:
|
||||||
|
# 如果是最后一条消息,则认为是用户消息
|
||||||
|
if idx == len(messages) - 1:
|
||||||
|
role = "user"
|
||||||
|
else:
|
||||||
|
role = "model"
|
||||||
|
if parts:
|
||||||
|
if role == "system":
|
||||||
|
system_instruction_parts.extend(parts)
|
||||||
|
else:
|
||||||
|
converted_messages.append({"role": role, "parts": parts})
|
||||||
|
|
||||||
|
system_instruction = (
|
||||||
|
None
|
||||||
|
if not system_instruction_parts
|
||||||
|
else {
|
||||||
|
"role": "system",
|
||||||
|
"parts": system_instruction_parts,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return converted_messages, system_instruction
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
# app/services/chat/response_handler.py
|
# app/services/chat/response_handler.py
|
||||||
|
|
||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
@@ -7,7 +8,8 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
|
from app.utils.uploader import ImageUploaderFactory
|
||||||
|
|
||||||
|
|
||||||
class ResponseHandler(ABC):
|
class ResponseHandler(ABC):
|
||||||
@@ -135,67 +137,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
|||||||
candidate = response["candidates"][0]
|
candidate = response["candidates"][0]
|
||||||
content = candidate.get("content", {})
|
content = candidate.get("content", {})
|
||||||
parts = content.get("parts", [])
|
parts = content.get("parts", [])
|
||||||
# if "thinking" in model:
|
if not parts:
|
||||||
# if settings.SHOW_THINKING_PROCESS:
|
return "", []
|
||||||
# if len(parts) == 1:
|
|
||||||
# if self.thinking_first:
|
|
||||||
# self.thinking_first = False
|
|
||||||
# self.thinking_status = True
|
|
||||||
# text = "> thinking\n\n" + parts[0].get("text")
|
|
||||||
# else:
|
|
||||||
# text = parts[0].get("text")
|
|
||||||
|
|
||||||
# if len(parts) == 2:
|
|
||||||
# self.thinking_status = False
|
|
||||||
# if self.thinking_first:
|
|
||||||
# self.thinking_first = False
|
|
||||||
# text = (
|
|
||||||
# "> thinking\n\n"
|
|
||||||
# + parts[0].get("text")
|
|
||||||
# + "\n\n---\n> output\n\n"
|
|
||||||
# + parts[1].get("text")
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# text = (
|
|
||||||
# parts[0].get("text")
|
|
||||||
# + "\n\n---\n> output\n\n"
|
|
||||||
# + parts[1].get("text")
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# if len(parts) == 1:
|
|
||||||
# if self.thinking_first:
|
|
||||||
# self.thinking_first = False
|
|
||||||
# self.thinking_status = True
|
|
||||||
# text = ""
|
|
||||||
# elif self.thinking_status:
|
|
||||||
# text = ""
|
|
||||||
# else:
|
|
||||||
# text = parts[0].get("text")
|
|
||||||
|
|
||||||
# if len(parts) == 2:
|
|
||||||
# self.thinking_status = False
|
|
||||||
# if self.thinking_first:
|
|
||||||
# self.thinking_first = False
|
|
||||||
# text = parts[1].get("text")
|
|
||||||
# else:
|
|
||||||
# text = parts[1].get("text")
|
|
||||||
# else:
|
|
||||||
# if "text" in parts[0]:
|
|
||||||
# text = parts[0].get("text")
|
|
||||||
# elif "executableCode" in parts[0]:
|
|
||||||
# text = _format_code_block(parts[0]["executableCode"])
|
|
||||||
# elif "codeExecution" in parts[0]:
|
|
||||||
# text = _format_code_block(parts[0]["codeExecution"])
|
|
||||||
# elif "executableCodeResult" in parts[0]:
|
|
||||||
# text = _format_execution_result(
|
|
||||||
# parts[0]["executableCodeResult"]
|
|
||||||
# )
|
|
||||||
# elif "codeExecutionResult" in parts[0]:
|
|
||||||
# text = _format_execution_result(
|
|
||||||
# parts[0]["codeExecutionResult"]
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# text = ""
|
|
||||||
if "text" in parts[0]:
|
if "text" in parts[0]:
|
||||||
text = parts[0].get("text")
|
text = parts[0].get("text")
|
||||||
elif "executableCode" in parts[0]:
|
elif "executableCode" in parts[0]:
|
||||||
@@ -210,6 +153,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
|||||||
text = _format_execution_result(
|
text = _format_execution_result(
|
||||||
parts[0]["codeExecutionResult"]
|
parts[0]["codeExecutionResult"]
|
||||||
)
|
)
|
||||||
|
elif "inlineData" in parts[0]:
|
||||||
|
text = _extract_image_data(parts[0])
|
||||||
else:
|
else:
|
||||||
text = ""
|
text = ""
|
||||||
text = _add_search_link_text(model, candidate, text)
|
text = _add_search_link_text(model, candidate, text)
|
||||||
@@ -235,14 +180,40 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
|||||||
text = candidate["content"]["parts"][0]["text"]
|
text = candidate["content"]["parts"][0]["text"]
|
||||||
else:
|
else:
|
||||||
text = ""
|
text = ""
|
||||||
for part in candidate["content"]["parts"]:
|
if "parts" in candidate["content"]:
|
||||||
text += part.get("text", "")
|
for part in candidate["content"]["parts"]:
|
||||||
|
if "text" in part:
|
||||||
|
text += part["text"]
|
||||||
|
elif "inlineData" in part:
|
||||||
|
text += _extract_image_data(part)
|
||||||
|
|
||||||
|
|
||||||
text = _add_search_link_text(model, candidate, text)
|
text = _add_search_link_text(model, candidate, text)
|
||||||
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
|
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
|
||||||
else:
|
else:
|
||||||
text = "暂无返回"
|
text = "暂无返回"
|
||||||
return text, tool_calls
|
return text, tool_calls
|
||||||
|
|
||||||
|
def _extract_image_data(part: dict) -> str:
|
||||||
|
image_uploader = None
|
||||||
|
if settings.UPLOAD_PROVIDER == "smms":
|
||||||
|
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
||||||
|
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||||
|
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
|
||||||
|
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||||
|
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
|
||||||
|
current_date = time.strftime("%Y/%m/%d")
|
||||||
|
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||||
|
base64_data = part["inlineData"]["data"]
|
||||||
|
#将base64_data转成bytes数组
|
||||||
|
bytes_data = base64.b64decode(base64_data)
|
||||||
|
upload_response = image_uploader.upload(bytes_data,filename)
|
||||||
|
if upload_response.success:
|
||||||
|
text = f"\n\n\n\n"
|
||||||
|
else:
|
||||||
|
text = ""
|
||||||
|
return text
|
||||||
|
|
||||||
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
|
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
|
||||||
"""提取工具调用信息"""
|
"""提取工具调用信息"""
|
||||||
if not parts or not isinstance(parts, list):
|
if not parts or not isinstance(parts, list):
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
# app/services/chat/retry_handler.py
|
# app/services/chat/retry_handler.py
|
||||||
|
|
||||||
from typing import TypeVar, Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from app.core.logger import get_retry_logger
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
T = TypeVar('T')
|
from app.log.logger import get_retry_logger
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
logger = get_retry_logger()
|
logger = get_retry_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -25,17 +26,21 @@ class RetryHandler:
|
|||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_exception = e
|
last_exception = e
|
||||||
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
|
logger.warning(
|
||||||
|
f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}"
|
||||||
|
)
|
||||||
|
|
||||||
# 从函数参数中获取 key_manager
|
# 从函数参数中获取 key_manager
|
||||||
key_manager = kwargs.get('key_manager')
|
key_manager = kwargs.get("key_manager")
|
||||||
if key_manager:
|
if key_manager:
|
||||||
old_key = kwargs.get(self.key_arg)
|
old_key = kwargs.get(self.key_arg)
|
||||||
new_key = await key_manager.handle_api_failure(old_key)
|
new_key = await key_manager.handle_api_failure(old_key)
|
||||||
kwargs[self.key_arg] = new_key
|
kwargs[self.key_arg] = new_key
|
||||||
logger.info(f"Switched to new API key: {new_key}")
|
logger.info(f"Switched to new API key: {new_key}")
|
||||||
|
|
||||||
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
|
logger.error(
|
||||||
|
f"All retry attempts failed, raising final exception: {str(last_exception)}"
|
||||||
|
)
|
||||||
raise last_exception
|
raise last_exception
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -2,8 +2,17 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
from typing import Any, List, AsyncGenerator, Callable
|
from typing import Any, AsyncGenerator, Callable, List
|
||||||
from app.core.logger import get_openai_logger, get_gemini_logger
|
|
||||||
|
from app.config.config import settings
|
||||||
|
from app.core.constants import (
|
||||||
|
DEFAULT_STREAM_CHUNK_SIZE,
|
||||||
|
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||||
|
DEFAULT_STREAM_MAX_DELAY,
|
||||||
|
DEFAULT_STREAM_MIN_DELAY,
|
||||||
|
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||||
|
)
|
||||||
|
from app.log.logger import get_gemini_logger, get_openai_logger
|
||||||
|
|
||||||
logger_openai = get_openai_logger()
|
logger_openai = get_openai_logger()
|
||||||
logger_gemini = get_gemini_logger()
|
logger_gemini = get_gemini_logger()
|
||||||
@@ -11,19 +20,21 @@ logger_gemini = get_gemini_logger()
|
|||||||
|
|
||||||
class StreamOptimizer:
|
class StreamOptimizer:
|
||||||
"""流式输出优化器
|
"""流式输出优化器
|
||||||
|
|
||||||
提供流式输出优化功能,包括智能延迟调整和长文本分块输出。
|
提供流式输出优化功能,包括智能延迟调整和长文本分块输出。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
logger=None,
|
self,
|
||||||
min_delay: float = 0.016,
|
logger=None,
|
||||||
max_delay: float = 0.024,
|
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
|
||||||
short_text_threshold: int = 10,
|
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
|
||||||
long_text_threshold: int = 100,
|
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||||
chunk_size: int = 10):
|
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||||
|
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
|
||||||
|
):
|
||||||
"""初始化流式输出优化器
|
"""初始化流式输出优化器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
logger: 日志记录器
|
logger: 日志记录器
|
||||||
min_delay: 最小延迟时间(秒)
|
min_delay: 最小延迟时间(秒)
|
||||||
@@ -38,13 +49,13 @@ class StreamOptimizer:
|
|||||||
self.short_text_threshold = short_text_threshold
|
self.short_text_threshold = short_text_threshold
|
||||||
self.long_text_threshold = long_text_threshold
|
self.long_text_threshold = long_text_threshold
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
|
|
||||||
def calculate_delay(self, text_length: int) -> float:
|
def calculate_delay(self, text_length: int) -> float:
|
||||||
"""根据文本长度计算延迟时间
|
"""根据文本长度计算延迟时间
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
text_length: 文本长度
|
text_length: 文本长度
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
延迟时间(秒)
|
延迟时间(秒)
|
||||||
"""
|
"""
|
||||||
@@ -57,42 +68,48 @@ class StreamOptimizer:
|
|||||||
else:
|
else:
|
||||||
# 中等长度文本使用线性插值计算延迟
|
# 中等长度文本使用线性插值计算延迟
|
||||||
# 使用对数函数使延迟变化更平滑
|
# 使用对数函数使延迟变化更平滑
|
||||||
ratio = math.log(text_length / self.short_text_threshold) / math.log(self.long_text_threshold / self.short_text_threshold)
|
ratio = math.log(text_length / self.short_text_threshold) / math.log(
|
||||||
|
self.long_text_threshold / self.short_text_threshold
|
||||||
|
)
|
||||||
return self.max_delay - ratio * (self.max_delay - self.min_delay)
|
return self.max_delay - ratio * (self.max_delay - self.min_delay)
|
||||||
|
|
||||||
def split_text_into_chunks(self, text: str) -> List[str]:
|
def split_text_into_chunks(self, text: str) -> List[str]:
|
||||||
"""将文本分割成小块
|
"""将文本分割成小块
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
text: 要分割的文本
|
text: 要分割的文本
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
文本块列表
|
文本块列表
|
||||||
"""
|
"""
|
||||||
return [text[i:i+self.chunk_size] for i in range(0, len(text), self.chunk_size)]
|
return [
|
||||||
|
text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
|
||||||
async def optimize_stream_output(self,
|
]
|
||||||
text: str,
|
|
||||||
create_response_chunk: Callable[[str], Any],
|
async def optimize_stream_output(
|
||||||
format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]:
|
self,
|
||||||
|
text: str,
|
||||||
|
create_response_chunk: Callable[[str], Any],
|
||||||
|
format_chunk: Callable[[Any], str],
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
"""优化流式输出
|
"""优化流式输出
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
text: 要输出的文本
|
text: 要输出的文本
|
||||||
create_response_chunk: 创建响应块的函数,接收文本,返回响应块
|
create_response_chunk: 创建响应块的函数,接收文本,返回响应块
|
||||||
format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串
|
format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
异步生成器,生成格式化后的响应块
|
异步生成器,生成格式化后的响应块
|
||||||
"""
|
"""
|
||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 计算智能延迟时间
|
# 计算智能延迟时间
|
||||||
delay = self.calculate_delay(len(text))
|
delay = self.calculate_delay(len(text))
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
|
self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
|
||||||
|
|
||||||
# 根据文本长度决定输出方式
|
# 根据文本长度决定输出方式
|
||||||
if len(text) >= self.long_text_threshold:
|
if len(text) >= self.long_text_threshold:
|
||||||
# 长文本:分块输出
|
# 长文本:分块输出
|
||||||
@@ -112,5 +129,20 @@ class StreamOptimizer:
|
|||||||
|
|
||||||
|
|
||||||
# 创建默认的优化器实例,可以直接导入使用
|
# 创建默认的优化器实例,可以直接导入使用
|
||||||
openai_optimizer = StreamOptimizer(logger=logger_openai)
|
openai_optimizer = StreamOptimizer(
|
||||||
gemini_optimizer = StreamOptimizer(logger=logger_gemini)
|
logger=logger_openai,
|
||||||
|
min_delay=settings.STREAM_MIN_DELAY,
|
||||||
|
max_delay=settings.STREAM_MAX_DELAY,
|
||||||
|
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
||||||
|
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
||||||
|
chunk_size=settings.STREAM_CHUNK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
gemini_optimizer = StreamOptimizer(
|
||||||
|
logger=logger_gemini,
|
||||||
|
min_delay=settings.STREAM_MIN_DELAY,
|
||||||
|
max_delay=settings.STREAM_MAX_DELAY,
|
||||||
|
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
||||||
|
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
||||||
|
chunk_size=settings.STREAM_CHUNK_SIZE,
|
||||||
|
)
|
||||||
@@ -133,3 +133,23 @@ def get_retry_logger():
|
|||||||
|
|
||||||
def get_image_create_logger():
|
def get_image_create_logger():
|
||||||
return Logger.setup_logger("image_create")
|
return Logger.setup_logger("image_create")
|
||||||
|
|
||||||
|
|
||||||
|
def get_exceptions_logger():
|
||||||
|
return Logger.setup_logger("exceptions")
|
||||||
|
|
||||||
|
|
||||||
|
def get_application_logger():
|
||||||
|
return Logger.setup_logger("application")
|
||||||
|
|
||||||
|
|
||||||
|
def get_initialization_logger():
|
||||||
|
return Logger.setup_logger("initialization")
|
||||||
|
|
||||||
|
|
||||||
|
def get_middleware_logger():
|
||||||
|
return Logger.setup_logger("middleware")
|
||||||
|
|
||||||
|
|
||||||
|
def get_routes_logger():
|
||||||
|
return Logger.setup_logger("routes")
|
||||||
132
app/main.py
132
app/main.py
@@ -1,134 +1,18 @@
|
|||||||
from fastapi import FastAPI, Request
|
"""
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
应用程序入口模块
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
"""
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from app.core.logger import get_main_logger
|
|
||||||
from app.core.security import verify_auth_token
|
|
||||||
from app.services.key_manager import get_key_manager_instance
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
from app.api import gemini_routes, openai_routes
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
from app.core.application import create_app
|
||||||
|
from app.log.logger import get_main_logger
|
||||||
|
|
||||||
|
# 创建应用程序实例
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = get_main_logger()
|
logger = get_main_logger()
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
# 配置Jinja2模板
|
|
||||||
templates = Jinja2Templates(directory="app/templates")
|
|
||||||
|
|
||||||
# 配置静态文件
|
|
||||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
|
||||||
|
|
||||||
# 创建 KeyManager 实例
|
|
||||||
key_manager = None
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event():
|
|
||||||
global key_manager
|
|
||||||
logger.info("Application starting up...")
|
|
||||||
try:
|
|
||||||
key_manager = await get_key_manager_instance(settings.API_KEYS)
|
|
||||||
logger.info("KeyManager initialized successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize KeyManager: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 添加中间件来处理未经身份验证的请求
|
|
||||||
@app.middleware("http")
|
|
||||||
async def auth_middleware(request: Request, call_next):
|
|
||||||
# 允许 gemini_routes 和 openai_routes 中的端点绕过身份验证
|
|
||||||
if (request.url.path not in ["/", "/auth"] and
|
|
||||||
not request.url.path.startswith("/static") and
|
|
||||||
not request.url.path.startswith("/gemini") and
|
|
||||||
not request.url.path.startswith("/v1") and
|
|
||||||
not request.url.path.startswith("/v1beta") and
|
|
||||||
not request.url.path.startswith("/health") and
|
|
||||||
not request.url.path.startswith("/hf")):
|
|
||||||
auth_token = request.cookies.get("auth_token")
|
|
||||||
if not auth_token or not verify_auth_token(auth_token):
|
|
||||||
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
|
||||||
return RedirectResponse(url="/")
|
|
||||||
logger.debug("Request authenticated successfully")
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
# 添加请求日志中间件
|
|
||||||
# app.add_middleware(RequestLoggingMiddleware)
|
|
||||||
|
|
||||||
# 配置CORS中间件
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"], # 生产环境建议配置具体的域名
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
|
|
||||||
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
|
||||||
expose_headers=["*"], # 允许前端访问的响应头
|
|
||||||
max_age=600, # 预检请求缓存时间(秒)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 包含所有路由
|
|
||||||
app.include_router(openai_routes.router)
|
|
||||||
app.include_router(gemini_routes.router)
|
|
||||||
app.include_router(gemini_routes.router_v1beta)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/", response_class=HTMLResponse)
|
|
||||||
async def auth_page(request: Request):
|
|
||||||
return templates.TemplateResponse("auth.html", {"request": request})
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/auth")
|
|
||||||
async def authenticate(request: Request):
|
|
||||||
try:
|
|
||||||
form = await request.form()
|
|
||||||
auth_token = form.get("auth_token")
|
|
||||||
if not auth_token:
|
|
||||||
logger.warning("Authentication attempt with empty token")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
|
|
||||||
if verify_auth_token(auth_token):
|
|
||||||
logger.info("Successful authentication")
|
|
||||||
response = RedirectResponse(url="/keys", status_code=302)
|
|
||||||
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
|
|
||||||
return response
|
|
||||||
logger.warning("Failed authentication attempt with invalid token")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Authentication error: {str(e)}")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
|
|
||||||
@app.get("/keys", response_class=HTMLResponse)
|
|
||||||
async def keys_page(request: Request):
|
|
||||||
try:
|
|
||||||
auth_token = request.cookies.get("auth_token")
|
|
||||||
if not auth_token or not verify_auth_token(auth_token):
|
|
||||||
logger.warning("Unauthorized access attempt to keys page")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
|
|
||||||
keys_status = await key_manager.get_keys_by_status()
|
|
||||||
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
|
||||||
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
|
|
||||||
return templates.TemplateResponse("keys_status.html", {
|
|
||||||
"request": request,
|
|
||||||
"valid_keys": keys_status["valid_keys"],
|
|
||||||
"invalid_keys": keys_status["invalid_keys"],
|
|
||||||
"total": total
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving keys status: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check(request: Request):
|
|
||||||
logger.info("Health check endpoint called")
|
|
||||||
return {"status": "healthy"}
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info("Starting application server...")
|
logger.info("Starting application server...")
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||||
|
|||||||
73
app/middleware/middleware.py
Normal file
73
app/middleware/middleware.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""
|
||||||
|
中间件配置模块,负责设置和配置应用程序的中间件
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
||||||
|
from app.core.constants import API_VERSION
|
||||||
|
from app.core.security import verify_auth_token
|
||||||
|
from app.log.logger import get_middleware_logger
|
||||||
|
|
||||||
|
logger = get_middleware_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
认证中间件,处理未经身份验证的请求
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# 允许特定路径绕过身份验证
|
||||||
|
if (
|
||||||
|
request.url.path not in ["/", "/auth"]
|
||||||
|
and not request.url.path.startswith("/static")
|
||||||
|
and not request.url.path.startswith("/gemini")
|
||||||
|
and not request.url.path.startswith("/v1")
|
||||||
|
and not request.url.path.startswith(f"/{API_VERSION}")
|
||||||
|
and not request.url.path.startswith("/health")
|
||||||
|
and not request.url.path.startswith("/hf")
|
||||||
|
):
|
||||||
|
|
||||||
|
auth_token = request.cookies.get("auth_token")
|
||||||
|
if not auth_token or not verify_auth_token(auth_token):
|
||||||
|
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
||||||
|
return RedirectResponse(url="/")
|
||||||
|
logger.debug("Request authenticated successfully")
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def setup_middlewares(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置应用程序的中间件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
# 添加认证中间件
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
|
# 添加请求日志中间件(可选,默认注释掉)
|
||||||
|
# app.add_middleware(RequestLoggingMiddleware)
|
||||||
|
|
||||||
|
# 配置CORS中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 生产环境建议配置具体的域名
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=[
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"DELETE",
|
||||||
|
"OPTIONS",
|
||||||
|
], # 明确指定允许的HTTP方法
|
||||||
|
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
||||||
|
expose_headers=["*"], # 允许前端访问的响应头
|
||||||
|
max_age=600, # 预检请求缓存时间(秒)
|
||||||
|
)
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
import json
|
|
||||||
from app.core.logger import get_request_logger
|
from app.log.logger import get_request_logger
|
||||||
|
|
||||||
logger = get_request_logger()
|
logger = get_request_logger()
|
||||||
|
|
||||||
@@ -20,7 +22,9 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
# 尝试格式化JSON
|
# 尝试格式化JSON
|
||||||
try:
|
try:
|
||||||
formatted_body = json.loads(body_str)
|
formatted_body = json.loads(body_str)
|
||||||
logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
|
logger.info(
|
||||||
|
f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.info("Request body is not valid JSON.")
|
logger.info("Request body is not valid JSON.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
178
app/router/gemini_routes.py
Normal file
178
app/router/gemini_routes.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from copy import deepcopy
|
||||||
|
from app.config.config import settings
|
||||||
|
from app.log.logger import get_gemini_logger
|
||||||
|
from app.core.security import SecurityService
|
||||||
|
from app.domain.gemini_models import GeminiContent, GeminiRequest
|
||||||
|
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||||
|
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||||
|
from app.service.model.model_service import ModelService
|
||||||
|
from app.handler.retry_handler import RetryHandler
|
||||||
|
from app.core.constants import API_VERSION
|
||||||
|
|
||||||
|
# 路由设置
|
||||||
|
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
||||||
|
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
||||||
|
logger = get_gemini_logger()
|
||||||
|
|
||||||
|
# 初始化服务
|
||||||
|
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||||
|
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_key_manager():
|
||||||
|
"""获取密钥管理器实例"""
|
||||||
|
return await get_key_manager_instance()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
|
||||||
|
"""获取下一个可用的API密钥"""
|
||||||
|
return await key_manager.get_next_working_key()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models")
|
||||||
|
@router_v1beta.get("/models")
|
||||||
|
async def list_models(
|
||||||
|
_=Depends(security_service.verify_key),
|
||||||
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
|
):
|
||||||
|
"""获取可用的Gemini模型列表"""
|
||||||
|
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
||||||
|
logger.info("Handling Gemini models list request")
|
||||||
|
|
||||||
|
api_key = await key_manager.get_next_working_key()
|
||||||
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
|
models_json = model_service.get_gemini_models(api_key)
|
||||||
|
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
||||||
|
|
||||||
|
# 添加搜索模型
|
||||||
|
if model_service.search_models:
|
||||||
|
for name in model_service.search_models:
|
||||||
|
model = model_mapping.get(name)
|
||||||
|
if not model:
|
||||||
|
continue
|
||||||
|
|
||||||
|
item = deepcopy(model)
|
||||||
|
item["name"] = f"models/{name}-search"
|
||||||
|
display_name = f'{item.get("displayName")} For Search'
|
||||||
|
item["displayName"] = display_name
|
||||||
|
item["description"] = display_name
|
||||||
|
|
||||||
|
models_json["models"].append(item)
|
||||||
|
|
||||||
|
# 添加图像生成模型
|
||||||
|
if model_service.image_models:
|
||||||
|
for name in model_service.image_models:
|
||||||
|
model = model_mapping.get(name)
|
||||||
|
if not model:
|
||||||
|
continue
|
||||||
|
|
||||||
|
item = deepcopy(model)
|
||||||
|
item["name"] = f"models/{name}-image"
|
||||||
|
display_name = f'{item.get("displayName")} For Image'
|
||||||
|
item["displayName"] = display_name
|
||||||
|
item["description"] = display_name
|
||||||
|
|
||||||
|
models_json["models"].append(item)
|
||||||
|
|
||||||
|
return models_json
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/models/{model_name}:generateContent")
|
||||||
|
@router_v1beta.post("/models/{model_name}:generateContent")
|
||||||
|
@RetryHandler(max_retries=3, key_arg="api_key")
|
||||||
|
async def generate_content(
|
||||||
|
model_name: str,
|
||||||
|
request: GeminiRequest,
|
||||||
|
_=Depends(security_service.verify_goog_api_key),
|
||||||
|
api_key: str = Depends(get_next_working_key),
|
||||||
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
|
):
|
||||||
|
"""非流式生成内容"""
|
||||||
|
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
||||||
|
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||||
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
|
if not model_service.check_model_support(model_name):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||||
|
response = await chat_service.generate_content(
|
||||||
|
model=model_name,
|
||||||
|
request=request,
|
||||||
|
api_key=api_key
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/models/{model_name}:streamGenerateContent")
|
||||||
|
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
||||||
|
@RetryHandler(max_retries=3, key_arg="api_key")
|
||||||
|
async def stream_generate_content(
|
||||||
|
model_name: str,
|
||||||
|
request: GeminiRequest,
|
||||||
|
_=Depends(security_service.verify_goog_api_key),
|
||||||
|
api_key: str = Depends(get_next_working_key),
|
||||||
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
|
):
|
||||||
|
"""流式生成内容"""
|
||||||
|
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
||||||
|
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||||
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
|
if not model_service.check_model_support(model_name):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||||
|
response_stream = chat_service.stream_generate_content(
|
||||||
|
model=model_name,
|
||||||
|
request=request,
|
||||||
|
api_key=api_key
|
||||||
|
)
|
||||||
|
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Streaming request failed: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail="Streaming request failed") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/verify-key/{api_key}")
|
||||||
|
async def verify_key(api_key: str):
|
||||||
|
"""验证Gemini API密钥的有效性"""
|
||||||
|
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
||||||
|
logger.info("Verifying API key validity")
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_manager = await get_key_manager()
|
||||||
|
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||||
|
|
||||||
|
# 使用generate_content接口测试key的有效性
|
||||||
|
gemini_request = GeminiRequest(
|
||||||
|
contents=[
|
||||||
|
GeminiContent(
|
||||||
|
role="user",
|
||||||
|
parts=[{"text": "hi"}]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await chat_service.generate_content(
|
||||||
|
settings.TEST_MODEL,
|
||||||
|
gemini_request,
|
||||||
|
api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
if response:
|
||||||
|
return JSONResponse({"status": "valid"})
|
||||||
|
return JSONResponse({"status": "invalid"})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Key verification failed: {str(e)}")
|
||||||
|
return JSONResponse({"status": "invalid", "error": str(e)})
|
||||||
@@ -1,37 +1,46 @@
|
|||||||
from fastapi import HTTPException, APIRouter, Depends
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.core.logger import get_openai_logger
|
|
||||||
from app.core.security import SecurityService
|
from app.core.security import SecurityService
|
||||||
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
|
from app.domain.openai_models import (
|
||||||
from app.services.chat.retry_handler import RetryHandler
|
ChatRequest,
|
||||||
from app.services.embedding_service import EmbeddingService
|
EmbeddingRequest,
|
||||||
from app.services.image_create_service import ImageCreateService
|
ImageGenerationRequest,
|
||||||
from app.services.key_manager import KeyManager, get_key_manager_instance
|
)
|
||||||
from app.services.model_service import ModelService
|
from app.handler.retry_handler import RetryHandler
|
||||||
from app.services.openai_chat_service import OpenAIChatService
|
from app.log.logger import get_openai_logger
|
||||||
|
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||||
|
from app.service.embedding.embedding_service import EmbeddingService
|
||||||
|
from app.service.image.image_create_service import ImageCreateService
|
||||||
|
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||||
|
from app.service.model.model_service import ModelService
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = get_openai_logger()
|
logger = get_openai_logger()
|
||||||
|
|
||||||
# 初始化服务
|
# 初始化服务
|
||||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||||
model_service = ModelService(settings.MODEL_SEARCH)
|
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
|
||||||
embedding_service = EmbeddingService(settings.BASE_URL)
|
embedding_service = EmbeddingService(settings.BASE_URL)
|
||||||
image_create_service = ImageCreateService()
|
image_create_service = ImageCreateService()
|
||||||
|
|
||||||
|
|
||||||
async def get_key_manager():
|
async def get_key_manager():
|
||||||
return await get_key_manager_instance()
|
return await get_key_manager_instance()
|
||||||
|
|
||||||
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
|
|
||||||
|
async def get_next_working_key_wrapper(
|
||||||
|
key_manager: KeyManager = Depends(get_key_manager),
|
||||||
|
):
|
||||||
return await key_manager.get_next_working_key()
|
return await key_manager.get_next_working_key()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/models")
|
@router.get("/v1/models")
|
||||||
@router.get("/hf/v1/models")
|
@router.get("/hf/v1/models")
|
||||||
async def list_models(
|
async def list_models(
|
||||||
_=Depends(security_service.verify_authorization),
|
_=Depends(security_service.verify_authorization),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager),
|
||||||
):
|
):
|
||||||
logger.info("-" * 50 + "list_models" + "-" * 50)
|
logger.info("-" * 50 + "list_models" + "-" * 50)
|
||||||
logger.info("Handling models list request")
|
logger.info("Handling models list request")
|
||||||
@@ -41,7 +50,9 @@ async def list_models(
|
|||||||
return model_service.get_gemini_openai_models(api_key)
|
return model_service.get_gemini_openai_models(api_key)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting models list: {str(e)}")
|
logger.error(f"Error getting models list: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Internal server error while fetching models list"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/completions")
|
@router.post("/v1/chat/completions")
|
||||||
@@ -51,7 +62,7 @@ async def chat_completion(
|
|||||||
request: ChatRequest,
|
request: ChatRequest,
|
||||||
_=Depends(security_service.verify_authorization),
|
_=Depends(security_service.verify_authorization),
|
||||||
api_key: str = Depends(get_next_working_key_wrapper),
|
api_key: str = Depends(get_next_working_key_wrapper),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager),
|
||||||
):
|
):
|
||||||
# 如果model是imagen3,使用paid_key
|
# 如果model是imagen3,使用paid_key
|
||||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||||
@@ -61,6 +72,12 @@ async def chat_completion(
|
|||||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
|
if not model_service.check_model_support(request.model):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Model {request.model} is not supported"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 如果model是imagen3,使用paid_key
|
# 如果model是imagen3,使用paid_key
|
||||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||||
@@ -76,6 +93,7 @@ async def chat_completion(
|
|||||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/images/generations")
|
@router.post("/v1/images/generations")
|
||||||
@router.post("/hf/v1/images/generations")
|
@router.post("/hf/v1/images/generations")
|
||||||
async def generate_image(
|
async def generate_image(
|
||||||
@@ -91,14 +109,17 @@ async def generate_image(
|
|||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Image generation request failed: {str(e)}")
|
logger.error(f"Image generation request failed: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Image generation request failed") from e
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Image generation request failed"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/embeddings")
|
@router.post("/v1/embeddings")
|
||||||
@router.post("/hf/v1/embeddings")
|
@router.post("/hf/v1/embeddings")
|
||||||
async def embedding(
|
async def embedding(
|
||||||
request: EmbeddingRequest,
|
request: EmbeddingRequest,
|
||||||
_=Depends(security_service.verify_authorization),
|
_=Depends(security_service.verify_authorization),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager),
|
||||||
):
|
):
|
||||||
logger.info("-" * 50 + "embedding" + "-" * 50)
|
logger.info("-" * 50 + "embedding" + "-" * 50)
|
||||||
logger.info(f"Handling embedding request for model: {request.model}")
|
logger.info(f"Handling embedding request for model: {request.model}")
|
||||||
@@ -114,11 +135,12 @@ async def embedding(
|
|||||||
logger.error(f"Embedding request failed: {str(e)}")
|
logger.error(f"Embedding request failed: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Embedding request failed") from e
|
raise HTTPException(status_code=500, detail="Embedding request failed") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/keys/list")
|
@router.get("/v1/keys/list")
|
||||||
@router.get("/hf/v1/keys/list")
|
@router.get("/hf/v1/keys/list")
|
||||||
async def get_keys_list(
|
async def get_keys_list(
|
||||||
_=Depends(security_service.verify_auth_token),
|
_=Depends(security_service.verify_auth_token),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager),
|
||||||
):
|
):
|
||||||
"""获取有效和无效的API key列表"""
|
"""获取有效和无效的API key列表"""
|
||||||
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
|
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
|
||||||
@@ -129,13 +151,12 @@ async def get_keys_list(
|
|||||||
"status": "success",
|
"status": "success",
|
||||||
"data": {
|
"data": {
|
||||||
"valid_keys": keys_status["valid_keys"],
|
"valid_keys": keys_status["valid_keys"],
|
||||||
"invalid_keys": keys_status["invalid_keys"]
|
"invalid_keys": keys_status["invalid_keys"],
|
||||||
},
|
},
|
||||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting keys list: {str(e)}")
|
logger.error(f"Error getting keys list: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail="Internal server error while fetching keys list"
|
||||||
detail="Internal server error while fetching keys list"
|
|
||||||
) from e
|
) from e
|
||||||
114
app/router/routes.py
Normal file
114
app/router/routes.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""
|
||||||
|
路由配置模块,负责设置和配置应用程序的路由
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from app.core.security import verify_auth_token
|
||||||
|
from app.log.logger import get_routes_logger
|
||||||
|
from app.router import gemini_routes, openai_routes
|
||||||
|
from app.service.key.key_manager import get_key_manager_instance
|
||||||
|
|
||||||
|
logger = get_routes_logger()
|
||||||
|
|
||||||
|
# 配置Jinja2模板
|
||||||
|
templates = Jinja2Templates(directory="app/templates")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_routers(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置应用程序的路由
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
# 包含API路由
|
||||||
|
app.include_router(openai_routes.router)
|
||||||
|
app.include_router(gemini_routes.router)
|
||||||
|
app.include_router(gemini_routes.router_v1beta)
|
||||||
|
|
||||||
|
# 添加页面路由
|
||||||
|
setup_page_routes(app)
|
||||||
|
|
||||||
|
# 添加健康检查路由
|
||||||
|
setup_health_routes(app)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_page_routes(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置页面相关的路由
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@app.get("/", response_class=HTMLResponse)
|
||||||
|
async def auth_page(request: Request):
|
||||||
|
"""认证页面"""
|
||||||
|
return templates.TemplateResponse("auth.html", {"request": request})
|
||||||
|
|
||||||
|
@app.post("/auth")
|
||||||
|
async def authenticate(request: Request):
|
||||||
|
"""处理认证请求"""
|
||||||
|
try:
|
||||||
|
form = await request.form()
|
||||||
|
auth_token = form.get("auth_token")
|
||||||
|
if not auth_token:
|
||||||
|
logger.warning("Authentication attempt with empty token")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
|
if verify_auth_token(auth_token):
|
||||||
|
logger.info("Successful authentication")
|
||||||
|
response = RedirectResponse(url="/keys", status_code=302)
|
||||||
|
response.set_cookie(
|
||||||
|
key="auth_token", value=auth_token, httponly=True, max_age=3600
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
logger.warning("Failed authentication attempt with invalid token")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Authentication error: {str(e)}")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
|
@app.get("/keys", response_class=HTMLResponse)
|
||||||
|
async def keys_page(request: Request):
|
||||||
|
"""密钥管理页面"""
|
||||||
|
try:
|
||||||
|
auth_token = request.cookies.get("auth_token")
|
||||||
|
if not auth_token or not verify_auth_token(auth_token):
|
||||||
|
logger.warning("Unauthorized access attempt to keys page")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
|
key_manager = await get_key_manager_instance()
|
||||||
|
keys_status = await key_manager.get_keys_by_status()
|
||||||
|
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||||
|
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"keys_status.html",
|
||||||
|
{
|
||||||
|
"request": request,
|
||||||
|
"valid_keys": keys_status["valid_keys"],
|
||||||
|
"invalid_keys": keys_status["invalid_keys"],
|
||||||
|
"total": total,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving keys status: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def setup_health_routes(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置健康检查相关的路由
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check(request: Request):
|
||||||
|
"""健康检查端点"""
|
||||||
|
logger.info("Health check endpoint called")
|
||||||
|
return {"status": "healthy"}
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
# app/services/chat_service.py
|
# app/services/chat_service.py
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Any, AsyncGenerator, List
|
from typing import Any, AsyncGenerator, Dict, List
|
||||||
from app.core.logger import get_gemini_logger
|
|
||||||
from app.services.chat.api_client import GeminiApiClient
|
from app.config.config import settings
|
||||||
from app.services.chat.stream_optimizer import gemini_optimizer
|
from app.domain.gemini_models import GeminiRequest
|
||||||
from app.schemas.gemini_models import GeminiRequest
|
from app.handler.response_handler import GeminiResponseHandler
|
||||||
from app.core.config import settings
|
from app.handler.stream_optimizer import gemini_optimizer
|
||||||
from app.services.chat.response_handler import GeminiResponseHandler
|
from app.log.logger import get_gemini_logger
|
||||||
from app.services.key_manager import KeyManager
|
from app.service.client.api_client import GeminiApiClient
|
||||||
|
from app.service.key.key_manager import KeyManager
|
||||||
|
|
||||||
logger = get_gemini_logger()
|
logger = get_gemini_logger()
|
||||||
|
|
||||||
@@ -25,20 +26,43 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
|||||||
|
|
||||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
"""构建工具"""
|
"""构建工具"""
|
||||||
tools = []
|
|
||||||
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
|
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
model.endswith("-search") or "-thinking" in model
|
record = dict()
|
||||||
) and not _has_image_parts(payload.get("contents", [])):
|
for item in tools:
|
||||||
tools.append({"code_execution": {}})
|
if not item or not isinstance(item, dict):
|
||||||
if model.endswith("-search"):
|
continue
|
||||||
tools.append({"googleSearch": {}})
|
|
||||||
|
|
||||||
|
for k, v in item.items():
|
||||||
|
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||||
|
functions = record.get("functionDeclarations", [])
|
||||||
|
functions.extend(v)
|
||||||
|
record["functionDeclarations"] = functions
|
||||||
|
else:
|
||||||
|
record[k] = v
|
||||||
|
return record
|
||||||
|
|
||||||
|
tool = dict()
|
||||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||||
items = payload.get("tools", [])
|
items = payload.get("tools", [])
|
||||||
if items and isinstance(items, list):
|
if items and isinstance(items, list):
|
||||||
tools.extend(items)
|
tool.update(_merge_tools(items))
|
||||||
|
|
||||||
return tools
|
if (
|
||||||
|
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||||
|
and not (model.endswith("-search") or "-thinking" in model)
|
||||||
|
and not _has_image_parts(payload.get("contents", []))
|
||||||
|
):
|
||||||
|
tool["codeExecution"] = {}
|
||||||
|
if model.endswith("-search"):
|
||||||
|
tool["googleSearch"] = {}
|
||||||
|
|
||||||
|
# 解决 "Tool use with function calling is unsupported" 问题
|
||||||
|
if tool.get("functionDeclarations"):
|
||||||
|
tool.pop("googleSearch", None)
|
||||||
|
tool.pop("codeExecution", None)
|
||||||
|
|
||||||
|
return [tool]
|
||||||
|
|
||||||
|
|
||||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||||
@@ -49,28 +73,33 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
|||||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
|
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||||
]
|
]
|
||||||
return [
|
return [
|
||||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
|
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||||
"""构建请求payload"""
|
"""构建请求payload"""
|
||||||
payload = request.model_dump()
|
request_dict = request.model_dump()
|
||||||
return {
|
payload = {
|
||||||
"contents": payload.get("contents", []),
|
"contents": request_dict.get("contents", []),
|
||||||
"tools": _build_tools(model, payload),
|
"tools": _build_tools(model, request_dict),
|
||||||
"safetySettings": _get_safety_settings(model),
|
"safetySettings": _get_safety_settings(model),
|
||||||
"generationConfig": payload.get("generationConfig", {}),
|
"generationConfig": request_dict.get("generationConfig", {}),
|
||||||
"systemInstruction": payload.get("systemInstruction", [])
|
"systemInstruction": request_dict.get("systemInstruction", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||||
|
payload.pop("systemInstruction")
|
||||||
|
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
class GeminiChatService:
|
class GeminiChatService:
|
||||||
"""聊天服务"""
|
"""聊天服务"""
|
||||||
@@ -79,54 +108,68 @@ class GeminiChatService:
|
|||||||
self.api_client = GeminiApiClient(base_url)
|
self.api_client = GeminiApiClient(base_url)
|
||||||
self.key_manager = key_manager
|
self.key_manager = key_manager
|
||||||
self.response_handler = GeminiResponseHandler()
|
self.response_handler = GeminiResponseHandler()
|
||||||
|
|
||||||
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
|
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
|
||||||
"""从响应中提取文本内容"""
|
"""从响应中提取文本内容"""
|
||||||
if not response.get("candidates"):
|
if not response.get("candidates"):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
candidate = response["candidates"][0]
|
candidate = response["candidates"][0]
|
||||||
content = candidate.get("content", {})
|
content = candidate.get("content", {})
|
||||||
parts = content.get("parts", [])
|
parts = content.get("parts", [])
|
||||||
|
|
||||||
if parts and "text" in parts[0]:
|
if parts and "text" in parts[0]:
|
||||||
return parts[0].get("text", "")
|
return parts[0].get("text", "")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]:
|
def _create_char_response(
|
||||||
|
self, original_response: Dict[str, Any], text: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""创建包含指定文本的响应"""
|
"""创建包含指定文本的响应"""
|
||||||
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
|
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
|
||||||
if response_copy.get("candidates") and response_copy["candidates"][0].get("content", {}).get("parts"):
|
if response_copy.get("candidates") and response_copy["candidates"][0].get(
|
||||||
|
"content", {}
|
||||||
|
).get("parts"):
|
||||||
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
|
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
|
||||||
return response_copy
|
return response_copy
|
||||||
|
|
||||||
async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
|
async def generate_content(
|
||||||
|
self, model: str, request: GeminiRequest, api_key: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""生成内容"""
|
"""生成内容"""
|
||||||
payload = _build_payload(model, request)
|
payload = _build_payload(model, request)
|
||||||
response = await self.api_client.generate_content(payload, model, api_key)
|
response = await self.api_client.generate_content(payload, model, api_key)
|
||||||
return self.response_handler.handle_response(response, model, stream=False)
|
return self.response_handler.handle_response(response, model, stream=False)
|
||||||
|
|
||||||
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
|
async def stream_generate_content(
|
||||||
|
self, model: str, request: GeminiRequest, api_key: str
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
"""流式生成内容"""
|
"""流式生成内容"""
|
||||||
retries = 0
|
retries = 0
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
payload = _build_payload(model, request)
|
payload = _build_payload(model, request)
|
||||||
while retries < max_retries:
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
async for line in self.api_client.stream_generate_content(payload, model, api_key):
|
async for line in self.api_client.stream_generate_content(
|
||||||
|
payload, model, api_key
|
||||||
|
):
|
||||||
# print(line)
|
# print(line)
|
||||||
if line.startswith("data:"):
|
if line.startswith("data:"):
|
||||||
line = line[6:]
|
line = line[6:]
|
||||||
response_data = self.response_handler.handle_response(json.loads(line), model, stream=True)
|
response_data = self.response_handler.handle_response(
|
||||||
|
json.loads(line), model, stream=True
|
||||||
|
)
|
||||||
text = self._extract_text_from_response(response_data)
|
text = self._extract_text_from_response(response_data)
|
||||||
|
|
||||||
# 如果有文本内容,使用流式输出优化器处理
|
# 如果有文本内容,使用流式输出优化器处理
|
||||||
if text:
|
if text:
|
||||||
# 使用流式输出优化器处理文本输出
|
# 使用流式输出优化器处理文本输出
|
||||||
async for optimized_chunk in gemini_optimizer.optimize_stream_output(
|
async for (
|
||||||
|
optimized_chunk
|
||||||
|
) in gemini_optimizer.optimize_stream_output(
|
||||||
text,
|
text,
|
||||||
lambda t: self._create_char_response(response_data, t),
|
lambda t: self._create_char_response(response_data, t),
|
||||||
lambda c: "data: " + json.dumps(c) + "\n\n"
|
lambda c: "data: " + json.dumps(c) + "\n\n",
|
||||||
):
|
):
|
||||||
yield optimized_chunk
|
yield optimized_chunk
|
||||||
else:
|
else:
|
||||||
@@ -136,9 +179,13 @@ class GeminiChatService:
|
|||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
retries += 1
|
retries += 1
|
||||||
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
|
logger.warning(
|
||||||
|
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
|
||||||
|
)
|
||||||
api_key = await self.key_manager.handle_api_failure(api_key)
|
api_key = await self.key_manager.handle_api_failure(api_key)
|
||||||
logger.info(f"Switched to new API key: {api_key}")
|
logger.info(f"Switched to new API key: {api_key}")
|
||||||
if retries >= max_retries:
|
if retries >= max_retries:
|
||||||
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
|
logger.error(
|
||||||
|
f"Max retries ({max_retries}) reached for streaming. Raising error"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
@@ -1,17 +1,18 @@
|
|||||||
# app/services/chat_service.py
|
# app/services/chat_service.py
|
||||||
|
|
||||||
from copy import deepcopy
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
|
from copy import deepcopy
|
||||||
from app.core.logger import get_openai_logger
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||||
from app.services.chat.message_converter import OpenAIMessageConverter
|
|
||||||
from app.services.chat.response_handler import OpenAIResponseHandler
|
from app.config.config import settings
|
||||||
from app.services.chat.api_client import GeminiApiClient
|
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||||
from app.services.chat.stream_optimizer import openai_optimizer
|
from app.handler.message_converter import OpenAIMessageConverter
|
||||||
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
|
from app.handler.response_handler import OpenAIResponseHandler
|
||||||
from app.core.config import settings
|
from app.handler.stream_optimizer import openai_optimizer
|
||||||
from app.services.image_create_service import ImageCreateService
|
from app.log.logger import get_openai_logger
|
||||||
from app.services.key_manager import KeyManager
|
from app.service.client.api_client import GeminiApiClient
|
||||||
|
from app.service.image.image_create_service import ImageCreateService
|
||||||
|
from app.service.key.key_manager import KeyManager
|
||||||
|
|
||||||
logger = get_openai_logger()
|
logger = get_openai_logger()
|
||||||
|
|
||||||
@@ -27,30 +28,35 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _build_tools(
|
def _build_tools(
|
||||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""构建工具"""
|
"""构建工具"""
|
||||||
tools = []
|
tool = dict()
|
||||||
model = request.model
|
model = request.model
|
||||||
|
|
||||||
if (
|
if (
|
||||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||||
and not (model.endswith("-search") or "-thinking" in model)
|
and not (
|
||||||
and not _has_image_parts(messages)
|
model.endswith("-search")
|
||||||
|
or "-thinking" in model
|
||||||
|
or model.endswith("-image")
|
||||||
|
or model.endswith("-image-generation")
|
||||||
|
)
|
||||||
|
and not _has_image_parts(messages)
|
||||||
):
|
):
|
||||||
tools.append({"code_execution": {}})
|
tool["codeExecution"] = {}
|
||||||
if model.endswith("-search"):
|
if model.endswith("-search"):
|
||||||
tools.append({"googleSearch": {}})
|
tool["googleSearch"] = {}
|
||||||
|
|
||||||
# 将 request 中的 tools 合并到 tools 中
|
# 将 request 中的 tools 合并到 tools 中
|
||||||
if request.tools:
|
if request.tools:
|
||||||
function_declarations = []
|
function_declarations = []
|
||||||
for tool in request.tools:
|
for item in request.tools:
|
||||||
if not tool or not isinstance(tool, dict):
|
if not item or not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if tool.get("type", "") == "function" and tool.get("function"):
|
if item.get("type", "") == "function" and item.get("function"):
|
||||||
function = deepcopy(tool.get("function"))
|
function = deepcopy(item.get("function"))
|
||||||
parameters = function.get("parameters", {})
|
parameters = function.get("parameters", {})
|
||||||
if parameters.get("type") == "object" and not parameters.get("properties", {}):
|
if parameters.get("type") == "object" and not parameters.get("properties", {}):
|
||||||
function.pop("parameters", None)
|
function.pop("parameters", None)
|
||||||
@@ -60,14 +66,19 @@ def _build_tools(
|
|||||||
if function_declarations:
|
if function_declarations:
|
||||||
# 按照 function 的 name 去重
|
# 按照 function 的 name 去重
|
||||||
names, functions = set(), []
|
names, functions = set(), []
|
||||||
for item in function_declarations:
|
for fc in function_declarations:
|
||||||
if item.get("name") not in names:
|
if fc.get("name") not in names:
|
||||||
names.add(item.get("name"))
|
names.add(fc.get("name"))
|
||||||
functions.append(item)
|
functions.append(fc)
|
||||||
|
|
||||||
tools.append({"functionDeclarations": functions})
|
tool["functionDeclarations"] = functions
|
||||||
|
|
||||||
return tools
|
# 解决 "Tool use with function calling is unsupported" 问题
|
||||||
|
if tool.get("functionDeclarations"):
|
||||||
|
tool.pop("googleSearch", None)
|
||||||
|
tool.pop("codeExecution", None)
|
||||||
|
|
||||||
|
return [tool] if tool else []
|
||||||
|
|
||||||
|
|
||||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||||
@@ -95,7 +106,9 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
|||||||
|
|
||||||
|
|
||||||
def _build_payload(
|
def _build_payload(
|
||||||
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
|
request: ChatRequest,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
instruction: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""构建请求payload"""
|
"""构建请求payload"""
|
||||||
payload = {
|
payload = {
|
||||||
@@ -110,12 +123,16 @@ def _build_payload(
|
|||||||
"tools": _build_tools(request, messages),
|
"tools": _build_tools(request, messages),
|
||||||
"safetySettings": _get_safety_settings(request.model),
|
"safetySettings": _get_safety_settings(request.model),
|
||||||
}
|
}
|
||||||
|
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||||
|
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
instruction
|
instruction
|
||||||
and isinstance(instruction, dict)
|
and isinstance(instruction, dict)
|
||||||
and instruction.get("role") == "system"
|
and instruction.get("role") == "system"
|
||||||
and instruction.get("parts")
|
and instruction.get("parts")
|
||||||
|
and not request.model.endswith("-image")
|
||||||
|
and not request.model.endswith("-image-generation")
|
||||||
):
|
):
|
||||||
payload["systemInstruction"] = instruction
|
payload["systemInstruction"] = instruction
|
||||||
|
|
||||||
@@ -124,24 +141,27 @@ def _build_payload(
|
|||||||
|
|
||||||
class OpenAIChatService:
|
class OpenAIChatService:
|
||||||
"""聊天服务"""
|
"""聊天服务"""
|
||||||
|
|
||||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||||
self.message_converter = OpenAIMessageConverter()
|
self.message_converter = OpenAIMessageConverter()
|
||||||
self.response_handler = OpenAIResponseHandler(config=None)
|
self.response_handler = OpenAIResponseHandler(config=None)
|
||||||
self.api_client = GeminiApiClient(base_url)
|
self.api_client = GeminiApiClient(base_url)
|
||||||
self.key_manager = key_manager
|
self.key_manager = key_manager
|
||||||
self.image_create_service = ImageCreateService()
|
self.image_create_service = ImageCreateService()
|
||||||
|
|
||||||
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
|
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
|
||||||
"""从OpenAI响应块中提取文本内容"""
|
"""从OpenAI响应块中提取文本内容"""
|
||||||
if not chunk.get("choices"):
|
if not chunk.get("choices"):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
choice = chunk["choices"][0]
|
choice = chunk["choices"][0]
|
||||||
if "delta" in choice and "content" in choice["delta"]:
|
if "delta" in choice and "content" in choice["delta"]:
|
||||||
return choice["delta"]["content"]
|
return choice["delta"]["content"]
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]:
|
def _create_char_openai_chunk(
|
||||||
|
self, original_chunk: Dict[str, Any], text: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""创建包含指定文本的OpenAI响应块"""
|
"""创建包含指定文本的OpenAI响应块"""
|
||||||
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
|
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
|
||||||
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
|
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
|
||||||
@@ -149,9 +169,9 @@ class OpenAIChatService:
|
|||||||
return chunk_copy
|
return chunk_copy
|
||||||
|
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self,
|
self,
|
||||||
request: ChatRequest,
|
request: ChatRequest,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||||
"""创建聊天完成"""
|
"""创建聊天完成"""
|
||||||
# 转换消息格式
|
# 转换消息格式
|
||||||
@@ -165,7 +185,7 @@ class OpenAIChatService:
|
|||||||
return await self._handle_normal_completion(request.model, payload, api_key)
|
return await self._handle_normal_completion(request.model, payload, api_key)
|
||||||
|
|
||||||
async def _handle_normal_completion(
|
async def _handle_normal_completion(
|
||||||
self, model: str, payload: Dict[str, Any], api_key: str
|
self, model: str, payload: Dict[str, Any], api_key: str
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""处理普通聊天完成"""
|
"""处理普通聊天完成"""
|
||||||
response = await self.api_client.generate_content(payload, model, api_key)
|
response = await self.api_client.generate_content(payload, model, api_key)
|
||||||
@@ -174,15 +194,16 @@ class OpenAIChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_stream_completion(
|
async def _handle_stream_completion(
|
||||||
self, model: str, payload: Dict[str, Any], api_key: str
|
self, model: str, payload: Dict[str, Any], api_key: str
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""处理流式聊天完成,添加重试逻辑"""
|
"""处理流式聊天完成,添加重试逻辑"""
|
||||||
retries = 0
|
retries = 0
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
while retries < max_retries:
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
|
tool_call_flag = False
|
||||||
async for line in self.api_client.stream_generate_content(
|
async for line in self.api_client.stream_generate_content(
|
||||||
payload, model, api_key
|
payload, model, api_key
|
||||||
):
|
):
|
||||||
# print(line)
|
# print(line)
|
||||||
if line.startswith("data:"):
|
if line.startswith("data:"):
|
||||||
@@ -195,16 +216,25 @@ class OpenAIChatService:
|
|||||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||||
if text:
|
if text:
|
||||||
# 使用流式输出优化器处理文本输出
|
# 使用流式输出优化器处理文本输出
|
||||||
async for optimized_chunk in openai_optimizer.optimize_stream_output(
|
async for (
|
||||||
|
optimized_chunk
|
||||||
|
) in openai_optimizer.optimize_stream_output(
|
||||||
text,
|
text,
|
||||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
lambda t: self._create_char_openai_chunk(
|
||||||
lambda c: f"data: {json.dumps(c)}\n\n"
|
openai_chunk, t
|
||||||
|
),
|
||||||
|
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||||
):
|
):
|
||||||
yield optimized_chunk
|
yield optimized_chunk
|
||||||
else:
|
else:
|
||||||
# 如果没有文本内容(如工具调用等),整块输出
|
# 如果没有文本内容(如工具调用等),整块输出
|
||||||
|
if "tool_calls" in json.dumps(openai_chunk):
|
||||||
|
tool_call_flag = True
|
||||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
if tool_call_flag:
|
||||||
|
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls'))}\n\n"
|
||||||
|
else:
|
||||||
|
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
logger.info("Streaming completed successfully")
|
logger.info("Streaming completed successfully")
|
||||||
break # 成功后退出循环
|
break # 成功后退出循环
|
||||||
@@ -224,21 +254,23 @@ class OpenAIChatService:
|
|||||||
break
|
break
|
||||||
|
|
||||||
async def create_image_chat_completion(
|
async def create_image_chat_completion(
|
||||||
self,
|
self,
|
||||||
request: ChatRequest,
|
request: ChatRequest,
|
||||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||||
|
|
||||||
image_generate_request = ImageGenerationRequest()
|
image_generate_request = ImageGenerationRequest()
|
||||||
image_generate_request.prompt = request.messages[-1]["content"]
|
image_generate_request.prompt = request.messages[-1]["content"]
|
||||||
image_res = self.image_create_service.generate_images_chat(image_generate_request)
|
image_res = self.image_create_service.generate_images_chat(
|
||||||
|
image_generate_request
|
||||||
|
)
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
return self._handle_stream_image_completion(request.model,image_res)
|
return self._handle_stream_image_completion(request.model, image_res)
|
||||||
else:
|
else:
|
||||||
return self._handle_normal_image_completion(request.model,image_res)
|
return self._handle_normal_image_completion(request.model, image_res)
|
||||||
|
|
||||||
async def _handle_stream_image_completion(
|
async def _handle_stream_image_completion(
|
||||||
self, model: str, image_data: str
|
self, model: str, image_data: str
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
if image_data:
|
if image_data:
|
||||||
openai_chunk = self.response_handler.handle_image_chat_response(
|
openai_chunk = self.response_handler.handle_image_chat_response(
|
||||||
@@ -249,10 +281,12 @@ class OpenAIChatService:
|
|||||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||||
if text:
|
if text:
|
||||||
# 使用流式输出优化器处理文本输出
|
# 使用流式输出优化器处理文本输出
|
||||||
async for optimized_chunk in openai_optimizer.optimize_stream_output(
|
async for (
|
||||||
|
optimized_chunk
|
||||||
|
) in openai_optimizer.optimize_stream_output(
|
||||||
text,
|
text,
|
||||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||||
lambda c: f"data: {json.dumps(c)}\n\n"
|
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||||
):
|
):
|
||||||
yield optimized_chunk
|
yield optimized_chunk
|
||||||
else:
|
else:
|
||||||
@@ -261,11 +295,11 @@ class OpenAIChatService:
|
|||||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
logger.info("Image chat streaming completed successfully")
|
logger.info("Image chat streaming completed successfully")
|
||||||
|
|
||||||
def _handle_normal_image_completion(
|
def _handle_normal_image_completion(
|
||||||
self, model: str, image_data: str
|
self, model: str, image_data: str
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
|
||||||
return self.response_handler.handle_image_chat_response(
|
return self.response_handler.handle_image_chat_response(
|
||||||
image_data, model, stream=False, finish_reason="stop"
|
image_data, model, stream=False, finish_reason="stop"
|
||||||
)
|
)
|
||||||
@@ -4,6 +4,8 @@ from typing import Dict, Any, AsyncGenerator
|
|||||||
import httpx
|
import httpx
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from app.core.constants import DEFAULT_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
class ApiClient(ABC):
|
class ApiClient(ABC):
|
||||||
"""API客户端基类"""
|
"""API客户端基类"""
|
||||||
@@ -20,14 +22,22 @@ class ApiClient(ABC):
|
|||||||
class GeminiApiClient(ApiClient):
|
class GeminiApiClient(ApiClient):
|
||||||
"""Gemini API客户端"""
|
"""Gemini API客户端"""
|
||||||
|
|
||||||
def __init__(self, base_url: str, timeout: int = 300):
|
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
def _get_real_model(self, model: str) -> str:
|
||||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
|
||||||
if model.endswith("-search"):
|
if model.endswith("-search"):
|
||||||
model = model[:-7]
|
model = model[:-7]
|
||||||
|
if model.endswith("-image"):
|
||||||
|
model = model[:-6]
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||||
|
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||||
|
model = self._get_real_model(model)
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||||
response = await client.post(url, json=payload)
|
response = await client.post(url, json=payload)
|
||||||
@@ -38,8 +48,8 @@ class GeminiApiClient(ApiClient):
|
|||||||
|
|
||||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||||
if model.endswith("-search"):
|
model = self._get_real_model(model)
|
||||||
model = model[:-7]
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
from typing import Union, List
|
from typing import List, Union
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai.types import CreateEmbeddingResponse
|
from openai.types import CreateEmbeddingResponse
|
||||||
|
|
||||||
from app.core.logger import get_embeddings_logger
|
from app.log.logger import get_embeddings_logger
|
||||||
|
|
||||||
logger = get_embeddings_logger()
|
logger = get_embeddings_logger()
|
||||||
|
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
|
import base64
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
import base64
|
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.core.logger import get_image_create_logger
|
from app.core.constants import VALID_IMAGE_RATIOS
|
||||||
from app.core.uploader import ImageUploaderFactory
|
from app.domain.openai_models import ImageGenerationRequest
|
||||||
from app.schemas.openai_models import ImageGenerationRequest
|
from app.log.logger import get_image_create_logger
|
||||||
|
from app.utils.uploader import ImageUploaderFactory
|
||||||
|
|
||||||
logger = get_image_create_logger()
|
logger = get_image_create_logger()
|
||||||
|
|
||||||
@@ -26,35 +27,34 @@ class ImageCreateService:
|
|||||||
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 默认值
|
# 默认值
|
||||||
n = 1
|
n = 1
|
||||||
aspect_ratio = self.aspect_ratio
|
aspect_ratio = self.aspect_ratio
|
||||||
|
|
||||||
# 解析n参数
|
# 解析n参数
|
||||||
n_match = re.search(r'{n:(\d+)}', prompt)
|
n_match = re.search(r"{n:(\d+)}", prompt)
|
||||||
if n_match:
|
if n_match:
|
||||||
n = int(n_match.group(1))
|
n = int(n_match.group(1))
|
||||||
if n < 1 or n > 4:
|
if n < 1 or n > 4:
|
||||||
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||||
prompt = prompt.replace(n_match.group(0), '').strip()
|
prompt = prompt.replace(n_match.group(0), "").strip()
|
||||||
|
|
||||||
# 解析ratio参数
|
# 解析ratio参数
|
||||||
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
|
||||||
if ratio_match:
|
if ratio_match:
|
||||||
aspect_ratio = ratio_match.group(1)
|
aspect_ratio = ratio_match.group(1)
|
||||||
valid_ratios = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||||
if aspect_ratio not in valid_ratios:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(valid_ratios)}"
|
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||||
)
|
)
|
||||||
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
prompt = prompt.replace(ratio_match.group(0), "").strip()
|
||||||
|
|
||||||
return prompt, n, aspect_ratio
|
return prompt, n, aspect_ratio
|
||||||
|
|
||||||
def generate_images(self, request: ImageGenerationRequest):
|
def generate_images(self, request: ImageGenerationRequest):
|
||||||
client = genai.Client(api_key=self.paid_key)
|
client = genai.Client(api_key=self.paid_key)
|
||||||
|
|
||||||
if request.size == "1024x1024":
|
if request.size == "1024x1024":
|
||||||
self.aspect_ratio = "1:1"
|
self.aspect_ratio = "1:1"
|
||||||
elif request.size == "1792x1024":
|
elif request.size == "1792x1024":
|
||||||
@@ -67,13 +67,15 @@ class ImageCreateService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 解析prompt中的参数
|
# 解析prompt中的参数
|
||||||
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(request.prompt)
|
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
|
||||||
|
request.prompt
|
||||||
|
)
|
||||||
request.prompt = cleaned_prompt
|
request.prompt = cleaned_prompt
|
||||||
|
|
||||||
# 如果prompt中指定了n,则覆盖请求中的n
|
# 如果prompt中指定了n,则覆盖请求中的n
|
||||||
if prompt_n > 1:
|
if prompt_n > 1:
|
||||||
request.n = prompt_n
|
request.n = prompt_n
|
||||||
|
|
||||||
# 如果prompt中指定了ratio,则覆盖默认的aspect_ratio
|
# 如果prompt中指定了ratio,则覆盖默认的aspect_ratio
|
||||||
if prompt_ratio != self.aspect_ratio:
|
if prompt_ratio != self.aspect_ratio:
|
||||||
self.aspect_ratio = prompt_ratio
|
self.aspect_ratio = prompt_ratio
|
||||||
@@ -96,27 +98,49 @@ class ImageCreateService:
|
|||||||
for index, generated_image in enumerate(response.generated_images):
|
for index, generated_image in enumerate(response.generated_images):
|
||||||
image_data = generated_image.image.image_bytes
|
image_data = generated_image.image.image_bytes
|
||||||
image_uploader = None
|
image_uploader = None
|
||||||
if settings.UPLOAD_PROVIDER == "smms":
|
|
||||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
if request.response_format == "b64_json":
|
||||||
|
base64_image = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
images_data.append(
|
||||||
|
{"b64_json": base64_image, "revised_prompt": request.prompt}
|
||||||
|
)
|
||||||
|
else:
|
||||||
current_date = time.strftime("%Y/%m/%d")
|
current_date = time.strftime("%Y/%m/%d")
|
||||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||||
upload_response = image_uploader.upload(image_data,filename)
|
|
||||||
|
if settings.UPLOAD_PROVIDER == "smms":
|
||||||
if request.response_format == "b64_json":
|
image_uploader = ImageUploaderFactory.create(
|
||||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
provider=settings.UPLOAD_PROVIDER,
|
||||||
images_data.append({
|
api_key=settings.SMMS_SECRET_TOKEN,
|
||||||
"b64_json": base64_image,
|
)
|
||||||
"revised_prompt": request.prompt
|
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||||
})
|
image_uploader = ImageUploaderFactory.create(
|
||||||
else:
|
provider=settings.UPLOAD_PROVIDER,
|
||||||
images_data.append({
|
api_key=settings.PICGO_API_KEY,
|
||||||
"url": f"{upload_response.data.url}",
|
)
|
||||||
"revised_prompt": request.prompt
|
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||||
})
|
image_uploader = ImageUploaderFactory.create(
|
||||||
|
provider=settings.UPLOAD_PROVIDER,
|
||||||
|
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||||
|
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
|
||||||
|
)
|
||||||
|
|
||||||
|
upload_response = image_uploader.upload(image_data, filename)
|
||||||
|
|
||||||
|
images_data.append(
|
||||||
|
{
|
||||||
|
"url": f"{upload_response.data.url}",
|
||||||
|
"revised_prompt": request.prompt,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"created": int(time.time()), # Current timestamp
|
"created": int(time.time()), # Current timestamp
|
||||||
"data": images_data
|
"data": images_data,
|
||||||
}
|
}
|
||||||
return response_data
|
return response_data
|
||||||
else:
|
else:
|
||||||
@@ -128,9 +152,13 @@ class ImageCreateService:
|
|||||||
if image_datas:
|
if image_datas:
|
||||||
markdown_images = []
|
markdown_images = []
|
||||||
for index, image_data in enumerate(image_datas):
|
for index, image_data in enumerate(image_datas):
|
||||||
if 'url' in image_data:
|
if "url" in image_data:
|
||||||
markdown_images.append(f"")
|
markdown_images.append(
|
||||||
|
f""
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 如果是base64格式,创建data URL
|
# 如果是base64格式,创建data URL
|
||||||
markdown_images.append(f"")
|
markdown_images.append(
|
||||||
|
f""
|
||||||
|
)
|
||||||
return "\n".join(markdown_images)
|
return "\n".join(markdown_images)
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from app.core.logger import get_key_manager_logger
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
|
from app.config.config import settings
|
||||||
|
from app.log.logger import get_key_manager_logger
|
||||||
|
|
||||||
logger = get_key_manager_logger()
|
logger = get_key_manager_logger()
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ class KeyManager:
|
|||||||
|
|
||||||
async def get_paid_key(self) -> str:
|
async def get_paid_key(self) -> str:
|
||||||
return self.paid_key
|
return self.paid_key
|
||||||
|
|
||||||
async def get_next_key(self) -> str:
|
async def get_next_key(self) -> str:
|
||||||
"""获取下一个API key"""
|
"""获取下一个API key"""
|
||||||
async with self.key_cycle_lock:
|
async with self.key_cycle_lock:
|
||||||
@@ -70,7 +70,7 @@ class KeyManager:
|
|||||||
"""获取分类后的API key列表,包括失败次数"""
|
"""获取分类后的API key列表,包括失败次数"""
|
||||||
valid_keys = {}
|
valid_keys = {}
|
||||||
invalid_keys = {}
|
invalid_keys = {}
|
||||||
|
|
||||||
async with self.failure_count_lock:
|
async with self.failure_count_lock:
|
||||||
for key in self.api_keys:
|
for key in self.api_keys:
|
||||||
fail_count = self.key_failure_counts[key]
|
fail_count = self.key_failure_counts[key]
|
||||||
@@ -78,16 +78,14 @@ class KeyManager:
|
|||||||
valid_keys[key] = fail_count
|
valid_keys[key] = fail_count
|
||||||
else:
|
else:
|
||||||
invalid_keys[key] = fail_count
|
invalid_keys[key] = fail_count
|
||||||
|
|
||||||
return {
|
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
||||||
"valid_keys": valid_keys,
|
|
||||||
"invalid_keys": invalid_keys
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
_singleton_instance = None
|
_singleton_instance = None
|
||||||
_singleton_lock = asyncio.Lock()
|
_singleton_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
|
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
|
||||||
"""
|
"""
|
||||||
获取 KeyManager 单例实例。
|
获取 KeyManager 单例实例。
|
||||||
@@ -1,15 +1,20 @@
|
|||||||
import requests
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, Dict, Any
|
from typing import Any, Dict, Optional
|
||||||
from app.core.logger import get_model_logger
|
|
||||||
from app.core.config import settings
|
import requests
|
||||||
|
|
||||||
|
from app.config.config import settings
|
||||||
|
from app.log.logger import get_model_logger
|
||||||
|
|
||||||
logger = get_model_logger()
|
logger = get_model_logger()
|
||||||
|
|
||||||
|
|
||||||
class ModelService:
|
class ModelService:
|
||||||
def __init__(self, model_search: list):
|
def __init__(self, search_models: list, image_models: list):
|
||||||
self.model_search = model_search
|
self.search_models = search_models
|
||||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
self.image_models = image_models
|
||||||
|
self.base_url = settings.BASE_URL
|
||||||
|
self.filtered_models = settings.FILTERED_MODELS
|
||||||
|
|
||||||
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||||
url = f"{self.base_url}/models?key={api_key}"
|
url = f"{self.base_url}/models?key={api_key}"
|
||||||
@@ -18,6 +23,16 @@ class ModelService:
|
|||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
gemini_models = response.json()
|
gemini_models = response.json()
|
||||||
|
|
||||||
|
filtered_models_list = []
|
||||||
|
for model in gemini_models.get("models", []):
|
||||||
|
model_id = model["name"].split("/")[-1]
|
||||||
|
if model_id not in self.filtered_models:
|
||||||
|
filtered_models_list.append(model)
|
||||||
|
else:
|
||||||
|
logger.info(f"Filtered out model: {model_id}")
|
||||||
|
|
||||||
|
gemini_models["models"] = filtered_models_list
|
||||||
return gemini_models
|
return gemini_models
|
||||||
else:
|
else:
|
||||||
logger.error(f"Error: {response.status_code}")
|
logger.error(f"Error: {response.status_code}")
|
||||||
@@ -36,7 +51,7 @@ class ModelService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def convert_to_openai_models_format(
|
def convert_to_openai_models_format(
|
||||||
self, gemini_models: Dict[str, Any]
|
self, gemini_models: Dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
openai_format = {"object": "list", "data": [], "success": True}
|
openai_format = {"object": "list", "data": [], "success": True}
|
||||||
|
|
||||||
@@ -53,13 +68,31 @@ class ModelService:
|
|||||||
}
|
}
|
||||||
openai_format["data"].append(openai_model)
|
openai_format["data"].append(openai_model)
|
||||||
|
|
||||||
if model_id in self.model_search:
|
if model_id in self.search_models:
|
||||||
search_model = openai_model.copy()
|
search_model = openai_model.copy()
|
||||||
search_model["id"] = f"{model_id}-search"
|
search_model["id"] = f"{model_id}-search"
|
||||||
openai_format["data"].append(search_model)
|
openai_format["data"].append(search_model)
|
||||||
|
if model_id in self.image_models:
|
||||||
|
image_model = openai_model.copy()
|
||||||
|
image_model["id"] = f"{model_id}-image"
|
||||||
|
openai_format["data"].append(image_model)
|
||||||
|
|
||||||
if settings.CREATE_IMAGE_MODEL:
|
if settings.CREATE_IMAGE_MODEL:
|
||||||
image_model = openai_model.copy()
|
image_model = openai_model.copy()
|
||||||
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||||
openai_format["data"].append(image_model)
|
openai_format["data"].append(image_model)
|
||||||
return openai_format
|
return openai_format
|
||||||
|
|
||||||
|
def check_model_support(self, model: str) -> bool:
|
||||||
|
if not model or not isinstance(model, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
model = model.strip()
|
||||||
|
if model.endswith("-search"):
|
||||||
|
model = model[:-7]
|
||||||
|
return model in self.search_models
|
||||||
|
if model.endswith("-image"):
|
||||||
|
model = model[:-6]
|
||||||
|
return model in self.image_models
|
||||||
|
|
||||||
|
return model not in self.filtered_models
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
# app/services/chat/message_converter.py
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
|
||||||
|
|
||||||
|
|
||||||
class MessageConverter(ABC):
|
|
||||||
"""消息转换器基类"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_image(image_url: str) -> Dict[str, Any]:
|
|
||||||
if image_url.startswith("data:image"):
|
|
||||||
return {
|
|
||||||
"inline_data": {
|
|
||||||
"mime_type": "image/jpeg",
|
|
||||||
"data": image_url.split(",")[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"image_url": {
|
|
||||||
"url": image_url
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIMessageConverter(MessageConverter):
|
|
||||||
"""OpenAI消息格式转换器"""
|
|
||||||
|
|
||||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
|
||||||
converted_messages = []
|
|
||||||
system_instruction = None
|
|
||||||
|
|
||||||
for idx, msg in enumerate(messages):
|
|
||||||
role = msg.get("role", "")
|
|
||||||
if role not in SUPPORTED_ROLES:
|
|
||||||
if role == "tool":
|
|
||||||
role = "user"
|
|
||||||
else:
|
|
||||||
# 如果是最后一条消息,则认为是用户消息
|
|
||||||
if idx == len(messages) - 1:
|
|
||||||
role = "user"
|
|
||||||
else:
|
|
||||||
role = "model"
|
|
||||||
|
|
||||||
parts = []
|
|
||||||
if isinstance(msg["content"], str) and msg["content"]:
|
|
||||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
|
||||||
parts.append({"text": msg["content"]})
|
|
||||||
elif isinstance(msg["content"], list):
|
|
||||||
for content in msg["content"]:
|
|
||||||
if isinstance(content, str) and content:
|
|
||||||
parts.append({"text": content})
|
|
||||||
elif isinstance(content, dict):
|
|
||||||
if content["type"] == "text" and content["text"]:
|
|
||||||
parts.append({"text": content["text"]})
|
|
||||||
elif content["type"] == "image_url":
|
|
||||||
parts.append(_convert_image(content["image_url"]["url"]))
|
|
||||||
|
|
||||||
if parts:
|
|
||||||
if role == "system":
|
|
||||||
system_instruction = {"role": "system", "parts": parts}
|
|
||||||
else:
|
|
||||||
converted_messages.append({"role": role, "parts": parts})
|
|
||||||
|
|
||||||
return converted_messages, system_instruction
|
|
||||||
3
app/utils/__init__.py
Normal file
3
app/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
工具包初始化模块
|
||||||
|
"""
|
||||||
146
app/utils/helpers.py
Normal file
146
app/utils/helpers.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
通用工具函数模块
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
|
||||||
|
|
||||||
|
|
||||||
|
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||||
|
"""
|
||||||
|
从 base64 字符串中提取 MIME 类型和数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base64_string: 可能包含 MIME 类型信息的 base64 字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (mime_type, encoded_data)
|
||||||
|
"""
|
||||||
|
# 检查字符串是否以 "data:" 格式开始
|
||||||
|
if base64_string.startswith('data:'):
|
||||||
|
# 提取 MIME 类型和数据
|
||||||
|
pattern = DATA_URL_PATTERN
|
||||||
|
match = re.match(pattern, base64_string)
|
||||||
|
if match:
|
||||||
|
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||||
|
encoded_data = match.group(2)
|
||||||
|
return mime_type, encoded_data
|
||||||
|
|
||||||
|
# 如果不是预期格式,假定它只是数据部分
|
||||||
|
return None, base64_string
|
||||||
|
|
||||||
|
|
||||||
|
def convert_image_to_base64(url: str) -> str:
|
||||||
|
"""
|
||||||
|
将图片URL转换为base64编码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 图片URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: base64编码的图片数据
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 如果获取图片失败
|
||||||
|
"""
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# 将图片内容转换为base64
|
||||||
|
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||||
|
return img_data
|
||||||
|
else:
|
||||||
|
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
|
||||||
|
"""
|
||||||
|
格式化JSON响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 要格式化的数据
|
||||||
|
indent: 缩进空格数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化后的JSON字符串
|
||||||
|
"""
|
||||||
|
return json.dumps(data, indent=indent, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
|
||||||
|
"""
|
||||||
|
从prompt中解析参数
|
||||||
|
|
||||||
|
支持的格式:
|
||||||
|
- {n:数量} 例如: {n:2} 生成2张图片
|
||||||
|
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示文本
|
||||||
|
default_ratio: 默认比例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (清理后的提示文本, 图片数量, 比例)
|
||||||
|
"""
|
||||||
|
# 默认值
|
||||||
|
n = 1
|
||||||
|
aspect_ratio = default_ratio
|
||||||
|
|
||||||
|
# 解析n参数
|
||||||
|
n_match = re.search(r'{n:(\d+)}', prompt)
|
||||||
|
if n_match:
|
||||||
|
n = int(n_match.group(1))
|
||||||
|
if n < 1 or n > 4:
|
||||||
|
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||||
|
prompt = prompt.replace(n_match.group(0), '').strip()
|
||||||
|
|
||||||
|
# 解析ratio参数
|
||||||
|
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||||
|
if ratio_match:
|
||||||
|
aspect_ratio = ratio_match.group(1)
|
||||||
|
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||||
|
)
|
||||||
|
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||||
|
|
||||||
|
return prompt, n, aspect_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def extract_image_urls_from_markdown(text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
从Markdown文本中提取图片URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Markdown文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 图片URL列表
|
||||||
|
"""
|
||||||
|
pattern = IMAGE_URL_PATTERN
|
||||||
|
matches = re.findall(pattern, text)
|
||||||
|
return [match[1] for match in matches]
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_api_key(key: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查API密钥格式是否有效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: API密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果密钥格式有效则返回True
|
||||||
|
"""
|
||||||
|
# 检查Gemini API密钥格式
|
||||||
|
if key.startswith('AIza'):
|
||||||
|
return len(key) >= 30
|
||||||
|
|
||||||
|
# 检查OpenAI API密钥格式
|
||||||
|
if key.startswith('sk-'):
|
||||||
|
return len(key) >= 30
|
||||||
|
|
||||||
|
return False
|
||||||
393
app/utils/uploader.py
Normal file
393
app/utils/uploader.py
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
import requests
|
||||||
|
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
class UploadErrorType(Enum):
|
||||||
|
"""上传错误类型枚举"""
|
||||||
|
NETWORK_ERROR = "network_error" # 网络请求错误
|
||||||
|
AUTH_ERROR = "auth_error" # 认证错误
|
||||||
|
INVALID_FILE = "invalid_file" # 无效文件
|
||||||
|
SERVER_ERROR = "server_error" # 服务器错误
|
||||||
|
PARSE_ERROR = "parse_error" # 响应解析错误
|
||||||
|
UNKNOWN = "unknown" # 未知错误
|
||||||
|
|
||||||
|
|
||||||
|
class UploadError(Exception):
|
||||||
|
"""图片上传错误异常类"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
error_type: UploadErrorType = UploadErrorType.UNKNOWN,
|
||||||
|
status_code: Optional[int] = None,
|
||||||
|
details: Optional[dict] = None,
|
||||||
|
original_error: Optional[Exception] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化上传错误异常
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 错误消息
|
||||||
|
error_type: 错误类型
|
||||||
|
status_code: HTTP状态码
|
||||||
|
details: 详细错误信息
|
||||||
|
original_error: 原始异常
|
||||||
|
"""
|
||||||
|
self.message = message
|
||||||
|
self.error_type = error_type
|
||||||
|
self.status_code = status_code
|
||||||
|
self.details = details or {}
|
||||||
|
self.original_error = original_error
|
||||||
|
|
||||||
|
# 构建完整错误信息
|
||||||
|
full_message = f"[{error_type.value}] {message}"
|
||||||
|
if status_code:
|
||||||
|
full_message = f"{full_message} (Status: {status_code})"
|
||||||
|
if details:
|
||||||
|
full_message = f"{full_message} - Details: {details}"
|
||||||
|
|
||||||
|
super().__init__(full_message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_response(cls, response: Any, message: Optional[str] = None) -> "UploadError":
|
||||||
|
"""
|
||||||
|
从HTTP响应创建错误实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: HTTP响应对象
|
||||||
|
message: 自定义错误消息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
details = error_data.get("data", {})
|
||||||
|
return cls(
|
||||||
|
message=message or error_data.get("message", "Unknown error"),
|
||||||
|
error_type=UploadErrorType.SERVER_ERROR,
|
||||||
|
status_code=response.status_code,
|
||||||
|
details=details
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return cls(
|
||||||
|
message=message or "Failed to parse error response",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR,
|
||||||
|
status_code=response.status_code
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SmMsUploader(ImageUploader):
|
||||||
|
API_URL = "https://sm.ms/api/v2/upload"
|
||||||
|
|
||||||
|
def __init__(self, api_key: str):
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||||
|
try:
|
||||||
|
# 准备请求头
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Basic {self.api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 准备文件数据
|
||||||
|
files = {
|
||||||
|
"smfile": (filename, file, "image/png")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
self.API_URL,
|
||||||
|
headers=headers,
|
||||||
|
files=files
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查响应状态
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 解析响应
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# 验证上传是否成功
|
||||||
|
if not result.get("success"):
|
||||||
|
raise UploadError(result.get("message", "Upload failed"))
|
||||||
|
|
||||||
|
# 转换为统一格式
|
||||||
|
data = result["data"]
|
||||||
|
image_metadata = ImageMetadata(
|
||||||
|
width=data["width"],
|
||||||
|
height=data["height"],
|
||||||
|
filename=data["filename"],
|
||||||
|
size=data["size"],
|
||||||
|
url=data["url"],
|
||||||
|
delete_url=data["delete"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return UploadResponse(
|
||||||
|
success=True,
|
||||||
|
code="success",
|
||||||
|
message="Upload success",
|
||||||
|
data=image_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
# 处理网络请求相关错误
|
||||||
|
raise UploadError(f"Upload request failed: {str(e)}")
|
||||||
|
except (KeyError, ValueError) as e:
|
||||||
|
# 处理响应解析错误
|
||||||
|
raise UploadError(f"Invalid response format: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
# 处理其他未预期的错误
|
||||||
|
raise UploadError(f"Upload failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
class QiniuUploader(ImageUploader):
|
||||||
|
def __init__(self, access_key: str, secret_key: str):
|
||||||
|
self.access_key = access_key
|
||||||
|
self.secret_key = secret_key
|
||||||
|
|
||||||
|
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||||
|
# 实现七牛云的具体上传逻辑
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PicGoUploader(ImageUploader):
|
||||||
|
"""Chevereto API 图片上传器"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, api_url: str = "https://www.picgo.net/api/1/upload"):
|
||||||
|
"""
|
||||||
|
初始化 Chevereto 上传器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Chevereto API 密钥
|
||||||
|
api_url: Chevereto API 上传地址
|
||||||
|
"""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.api_url = api_url
|
||||||
|
|
||||||
|
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||||
|
"""
|
||||||
|
上传图片到 Chevereto 服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 图片文件二进制数据
|
||||||
|
filename: 文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UploadResponse: 上传响应对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UploadError: 上传失败时抛出异常
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 准备请求头
|
||||||
|
headers = {
|
||||||
|
"X-API-Key": self.api_key
|
||||||
|
}
|
||||||
|
|
||||||
|
# 准备文件数据
|
||||||
|
files = {
|
||||||
|
"source": (filename, file)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
self.api_url,
|
||||||
|
headers=headers,
|
||||||
|
files=files
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查响应状态
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 解析响应
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# 验证上传是否成功
|
||||||
|
if result.get("status_code") != 200:
|
||||||
|
error_message = "Upload failed"
|
||||||
|
if "error" in result:
|
||||||
|
error_message = result["error"].get("message", error_message)
|
||||||
|
raise UploadError(
|
||||||
|
message=error_message,
|
||||||
|
error_type=UploadErrorType.SERVER_ERROR,
|
||||||
|
status_code=result.get("status_code"),
|
||||||
|
details=result.get("error")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从响应中提取图片信息
|
||||||
|
image_data = result.get("image", {})
|
||||||
|
|
||||||
|
# 构建图片元数据
|
||||||
|
image_metadata = ImageMetadata(
|
||||||
|
width=image_data.get("width", 0),
|
||||||
|
height=image_data.get("height", 0),
|
||||||
|
filename=image_data.get("filename", filename),
|
||||||
|
size=image_data.get("size", 0),
|
||||||
|
url=image_data.get("url", ""),
|
||||||
|
delete_url=image_data.get("delete_url", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
return UploadResponse(
|
||||||
|
success=True,
|
||||||
|
code="success",
|
||||||
|
message=result.get("success", {}).get("message", "Upload success"),
|
||||||
|
data=image_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
# 处理网络请求相关错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Upload request failed: {str(e)}",
|
||||||
|
error_type=UploadErrorType.NETWORK_ERROR,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
except (KeyError, ValueError, TypeError) as e:
|
||||||
|
# 处理响应解析错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Invalid response format: {str(e)}",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
except UploadError:
|
||||||
|
# 重新抛出已经是 UploadError 类型的异常
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 处理其他未预期的错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Upload failed: {str(e)}",
|
||||||
|
error_type=UploadErrorType.UNKNOWN,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CloudFlareImgBedUploader(ImageUploader):
|
||||||
|
"""CloudFlare图床上传器"""
|
||||||
|
|
||||||
|
def __init__(self, auth_code: str, api_url: str):
|
||||||
|
"""
|
||||||
|
初始化CloudFlare图床上传器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_code: 认证码
|
||||||
|
api_url: 上传API地址
|
||||||
|
"""
|
||||||
|
self.auth_code = auth_code
|
||||||
|
self.api_url = api_url
|
||||||
|
|
||||||
|
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||||
|
"""
|
||||||
|
上传图片到CloudFlare图床
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 图片文件二进制数据
|
||||||
|
filename: 文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UploadResponse: 上传响应对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UploadError: 上传失败时抛出异常
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 准备请求URL(添加认证码参数,如果存在)
|
||||||
|
if self.auth_code:
|
||||||
|
request_url = f"{self.api_url}?authCode={self.auth_code}&uploadNameType=origin"
|
||||||
|
else:
|
||||||
|
request_url = f"{self.api_url}?uploadNameType=origin"
|
||||||
|
|
||||||
|
# 准备文件数据
|
||||||
|
files = {
|
||||||
|
"file": (filename, file)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
request_url,
|
||||||
|
files=files
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查响应状态
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 解析响应
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# 验证响应格式
|
||||||
|
if not result or not isinstance(result, list) or len(result) == 0:
|
||||||
|
raise UploadError(
|
||||||
|
message="Invalid response format",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取文件URL
|
||||||
|
file_path = result[0].get("src")
|
||||||
|
if not file_path:
|
||||||
|
raise UploadError(
|
||||||
|
message="Missing file URL in response",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建完整URL(如果返回的是相对路径)
|
||||||
|
base_url = self.api_url.split("/upload")[0]
|
||||||
|
full_url = file_path if file_path.startswith(("http://", "https://")) else f"{base_url}{file_path}"
|
||||||
|
|
||||||
|
# 构建图片元数据(注意:CloudFlare-ImgBed不返回所有元数据,所以部分字段为默认值)
|
||||||
|
image_metadata = ImageMetadata(
|
||||||
|
width=0, # CloudFlare-ImgBed不返回宽度
|
||||||
|
height=0, # CloudFlare-ImgBed不返回高度
|
||||||
|
filename=filename,
|
||||||
|
size=0, # CloudFlare-ImgBed不返回大小
|
||||||
|
url=full_url,
|
||||||
|
delete_url=None # CloudFlare-ImgBed不返回删除URL
|
||||||
|
)
|
||||||
|
|
||||||
|
return UploadResponse(
|
||||||
|
success=True,
|
||||||
|
code="success",
|
||||||
|
message="Upload success",
|
||||||
|
data=image_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
# 处理网络请求相关错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Upload request failed: {str(e)}",
|
||||||
|
error_type=UploadErrorType.NETWORK_ERROR,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
except (KeyError, ValueError, TypeError, IndexError) as e:
|
||||||
|
# 处理响应解析错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Invalid response format: {str(e)}",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
except UploadError:
|
||||||
|
# 重新抛出已经是 UploadError 类型的异常
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 处理其他未预期的错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Upload failed: {str(e)}",
|
||||||
|
error_type=UploadErrorType.UNKNOWN,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
|
||||||
|
class ImageUploaderFactory:
|
||||||
|
@staticmethod
|
||||||
|
def create(provider: str, **credentials) -> ImageUploader:
|
||||||
|
if provider == "smms":
|
||||||
|
return SmMsUploader(credentials["api_key"])
|
||||||
|
elif provider == "qiniu":
|
||||||
|
return QiniuUploader(
|
||||||
|
credentials["access_key"],
|
||||||
|
credentials["secret_key"]
|
||||||
|
)
|
||||||
|
elif provider == "picgo":
|
||||||
|
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
|
||||||
|
return PicGoUploader(credentials["api_key"], api_url)
|
||||||
|
elif provider == "cloudflare_imgbed":
|
||||||
|
return CloudFlareImgBedUploader(
|
||||||
|
credentials["auth_code"],
|
||||||
|
credentials["base_url"]
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
Reference in New Issue
Block a user