mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-06 20:32:47 +08:00
fix: 修复 OpenAI 和 Gemini API 调用重试逻辑及日志记录
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
.idea/
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from email.header import Header
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.schemas.gemini_models import GeminiRequest
|
||||
from app.services.chat_service import ChatService
|
||||
from app.services.key_manager import KeyManager
|
||||
from app.services.model_service import ModelService
|
||||
from app.schemas.gemini_models import GeminiRequest
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_gemini_logger
|
||||
|
||||
router = APIRouter(prefix="/gemini/v1beta")
|
||||
logger = get_gemini_logger()
|
||||
@@ -19,10 +18,11 @@ key_manager = KeyManager(settings.API_KEYS)
|
||||
model_service = ModelService(settings.MODEL_SEARCH)
|
||||
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models(
|
||||
key: str = None,
|
||||
token: str = Depends(security_service.verify_key),
|
||||
key: str = None,
|
||||
token: str = Depends(security_service.verify_key),
|
||||
):
|
||||
"""获取可用的Gemini模型列表"""
|
||||
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
||||
@@ -31,17 +31,18 @@ async def list_models(
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return model_service.get_gemini_models(api_key)
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
|
||||
):
|
||||
"""非流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
retries = 0
|
||||
@@ -66,19 +67,20 @@ async def generate_content(
|
||||
if retries >= MAX_RETRIES:
|
||||
logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
|
||||
):
|
||||
"""流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
|
||||
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
|
||||
try:
|
||||
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
@@ -87,6 +89,6 @@ async def stream_generate_content(
|
||||
api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
|
||||
@@ -48,9 +48,9 @@ async def chat_completion(
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
retries = 0
|
||||
MAX_RETRIES = 3
|
||||
max_retries = 3
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
while retries < max_retries:
|
||||
try:
|
||||
response = await chat_service.create_chat_completion(
|
||||
request=request,
|
||||
@@ -64,13 +64,13 @@ async def chat_completion(
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {MAX_RETRIES}"
|
||||
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {max_retries}"
|
||||
)
|
||||
api_key = await key_manager.handle_api_failure(api_key)
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
retries += 1
|
||||
if retries >= MAX_RETRIES:
|
||||
logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached. Raising error")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ class Settings(BaseSettings):
|
||||
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
@@ -5,24 +5,26 @@ import platform
|
||||
|
||||
# ANSI escape sequences used to colorize the level name of each log record.
# Each level is listed exactly once; the reset sequence is appended by the
# formatter that consumes this table.
COLORS = {
    'DEBUG': '\033[34m',  # blue
    'INFO': '\033[32m',  # green
    'WARNING': '\033[33m',  # yellow
    'ERROR': '\033[31m',  # red
    'CRITICAL': '\033[1;31m',  # bold red
}
|
||||
|
||||
# Turn on ANSI escape-sequence handling for the Windows console so the
# colored level names render instead of showing raw escape codes.
if platform.system() == 'Windows':
    import ctypes

    kernel32 = ctypes.windll.kernel32
    kernel32.SetConsoleMode(
        kernel32.GetStdHandle(-11),  # -11 is STD_OUTPUT_HANDLE
        7,  # processed output | wrap at EOL | virtual terminal processing
    )
|
||||
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""
|
||||
自定义的日志格式化器,添加颜色支持
|
||||
"""
|
||||
|
||||
def format(self, record):
|
||||
# 获取对应级别的颜色代码
|
||||
color = COLORS.get(record.levelname, '')
|
||||
@@ -30,6 +32,7 @@ class ColoredFormatter(logging.Formatter):
|
||||
record.levelname = f"{color}{record.levelname}\033[0m"
|
||||
return super().format(record)
|
||||
|
||||
|
||||
# 日志格式
|
||||
FORMATTER = ColoredFormatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
|
||||
@@ -44,13 +47,17 @@ LOG_LEVELS = {
|
||||
"critical": logging.CRITICAL,
|
||||
}
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
@staticmethod
|
||||
def setup_logger(
|
||||
name: str,
|
||||
level: str = "debug",
|
||||
name: str,
|
||||
level: str = "debug",
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
设置并获取logger
|
||||
@@ -82,30 +89,39 @@ class Logger:
|
||||
"""
|
||||
return Logger._loggers.get(name)
|
||||
|
||||
|
||||
# 预定义的loggers
|
||||
def get_openai_logger():
    """Return the logger registered under the name "openai"."""
    openai_logger = Logger.setup_logger(name="openai")
    return openai_logger
|
||||
|
||||
|
||||
def get_gemini_logger():
    """Return the logger registered under the name "gemini"."""
    gemini_logger = Logger.setup_logger(name="gemini")
    return gemini_logger
|
||||
|
||||
|
||||
def get_chat_logger():
    """Return the logger registered under the name "chat"."""
    chat_logger = Logger.setup_logger(name="chat")
    return chat_logger
|
||||
|
||||
|
||||
def get_model_logger():
    """Return the logger registered under the name "model"."""
    model_logger = Logger.setup_logger(name="model")
    return model_logger
|
||||
|
||||
|
||||
def get_security_logger():
    """Return the logger registered under the name "security"."""
    security_logger = Logger.setup_logger(name="security")
    return security_logger
|
||||
|
||||
|
||||
def get_key_manager_logger():
    """Return the logger registered under the name "key_manager".

    The duplicated second ``return`` was unreachable and has been removed.
    """
    return Logger.setup_logger("key_manager")
|
||||
|
||||
|
||||
def get_main_logger():
    """Return the logger registered under the name "main"."""
    main_logger = Logger.setup_logger(name="main")
    return main_logger
|
||||
|
||||
|
||||
def get_embeddings_logger():
    """Return the logger registered under the name "embeddings"."""
    embeddings_logger = Logger.setup_logger(name="embeddings")
    return embeddings_logger
|
||||
|
||||
|
||||
def get_request_logger():
    """Return the logger registered under the name "request".

    The duplicated second ``return`` was unreachable and has been removed.
    """
    return Logger.setup_logger("request")
|
||||
|
||||
@@ -3,15 +3,15 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import json
|
||||
from app.core.logger import get_request_logger
|
||||
|
||||
|
||||
logger = get_request_logger()
|
||||
|
||||
|
||||
# 添加中间件类
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# 记录请求路径
|
||||
logger.info(f"Request path: {request.url.path}")
|
||||
|
||||
|
||||
# 获取并记录请求体
|
||||
try:
|
||||
body = await request.body()
|
||||
@@ -33,4 +33,4 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
request._receive = receive
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
return response
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import List, Optional, Dict, Any, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SafetySetting(BaseModel):
    """A single safety rule: a harm category paired with a blocking threshold.

    Both fields are optional and default to ``None``. Each field is declared
    once; the earlier duplicate declarations were shadowed by the later ones.
    """

    # Harm category this rule applies to.
    category: Optional[
        Literal[
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        ]
    ] = None
    # Blocking threshold applied to that category.
    threshold: Optional[
        Literal[
            "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
            "BLOCK_LOW_AND_ABOVE",
            "BLOCK_MEDIUM_AND_ABOVE",
            "BLOCK_ONLY_HIGH",
            "BLOCK_NONE",
            "OFF",
        ]
    ] = None
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
@@ -36,4 +39,4 @@ class GeminiRequest(BaseModel):
|
||||
tools: Optional[List[Dict[str, Any]]] = []
|
||||
safetySettings: Optional[List[SafetySetting]] = None
|
||||
generationConfig: Optional[GenerationConfig] = None
|
||||
systemInstruction: Optional[SystemInstruction] = None
|
||||
systemInstruction: Optional[SystemInstruction] = None
|
||||
|
||||
@@ -11,61 +11,73 @@ from app.schemas.openai_models import ChatRequest
|
||||
logger = get_chat_logger()
|
||||
|
||||
|
||||
def _data_url_to_inline_data(image_url: str) -> dict:
    """Split a ``data:image/...`` URL into a Gemini ``inline_data`` part.

    Raises:
        ValueError: if the URL has no ``,`` separating header and payload.
    """
    header, separator, payload = image_url.partition(",")
    if not separator:
        raise ValueError("Malformed image data URL: missing ',' before the payload")
    # The header looks like "data:image/png;base64"; the media type sits
    # between the "data:" prefix and the first ";" parameter.
    mime_type = header[len("data:"):].split(";", 1)[0].strip()
    if "/" not in mime_type:
        # No usable media type in the header; keep the historical default.
        mime_type = "image/jpeg"
    return {"inline_data": {"mime_type": mime_type, "data": payload}}


def _content_item_to_part(item):
    """Translate one OpenAI content item into a Gemini part, or None to skip it."""
    if isinstance(item, str):
        return {"text": item}
    if not isinstance(item, dict):
        return None
    item_type = item.get("type")
    if item_type == "text":
        return {"text": item["text"]}
    if item_type == "image_url":
        image_url = item["image_url"]["url"]
        if image_url.startswith("data:image"):
            # Base64-encoded image embedded in the request itself.
            return _data_url_to_inline_data(image_url)
        # Ordinary remote image URL: forwarded unchanged.
        return {"image_url": {"url": image_url}}
    return None


def convert_messages_to_gemini_format(messages: list) -> list:
    """Convert OpenAI-style chat messages into Gemini-style message dicts.

    Args:
        messages: Dicts with a ``role`` and a ``content`` that is either a
            string or a list of strings / typed content dicts
            (``text`` or ``image_url``).

    Returns:
        One ``{"role": ..., "parts": [...]}`` dict per input message. The role
        is ``"user"`` for user messages and ``"model"`` for everything else.
        Content items of an unrecognized kind are skipped.

    Raises:
        ValueError: if an embedded image data URL has no payload separator.
    """
    converted_messages = []
    for msg in messages:
        role = "user" if msg["role"] == "user" else "model"
        content = msg["content"]

        if isinstance(content, str):
            # Plain text message.
            parts = [{"text": content}]
        elif isinstance(content, list):
            # Multi-part message, possibly containing images.
            candidates = (_content_item_to_part(item) for item in content)
            parts = [part for part in candidates if part is not None]
        else:
            parts = []

        converted_messages.append({"role": role, "parts": parts})

    return converted_messages
|
||||
|
||||
|
||||
def format_execution_result(result_data: dict) -> str:
    """Render a code-execution result as a Markdown snippet.

    Args:
        result_data: Mapping that may carry an ``outcome`` and an ``output``.

    Returns:
        A string with the outcome on a quoted line followed by the stripped
        output inside a ``plaintext`` fenced block.
    """
    outcome = result_data.get("outcome", "")
    # "output" can be present but null; treat that like a missing value
    # instead of failing on None.strip().
    output = (result_data.get("output") or "").strip()
    return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n"""
|
||||
|
||||
|
||||
def create_search_link(web):
    """Format one web grounding source as a Markdown bullet link.

    Args:
        web: Mapping with a ``uri`` and, usually, a ``title``.

    Returns:
        A newline-prefixed ``- [title](uri)`` bullet. When the title is
        missing or empty the URI is used as the link text.

    Raises:
        KeyError: if ``uri`` is absent.
    """
    uri = web["uri"]
    title = web.get("title") or uri
    return f'\n- [{title}]({uri})'
|
||||
|
||||
|
||||
class ChatService:
|
||||
def __init__(self, base_url: str, key_manager=None):
|
||||
self.base_url = base_url
|
||||
self.key_manager = key_manager
|
||||
|
||||
def convert_messages_to_gemini_format(self, messages: list) -> list:
|
||||
"""Convert OpenAI message format to Gemini format"""
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
parts = []
|
||||
|
||||
# 处理文本内容
|
||||
if isinstance(msg["content"], str):
|
||||
parts.append({"text": msg["content"]})
|
||||
# 处理包含图片的消息
|
||||
elif isinstance(msg["content"], list):
|
||||
for content in msg["content"]:
|
||||
if isinstance(content, str):
|
||||
parts.append({"text": content})
|
||||
elif isinstance(content, dict) and content["type"] == "text":
|
||||
parts.append({"text": content["text"]})
|
||||
elif isinstance(content, dict) and content["type"] == "image_url":
|
||||
# 处理图片URL
|
||||
image_url = content["image_url"]["url"]
|
||||
if image_url.startswith("data:image"):
|
||||
# 处理base64图片
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": image_url.split(",")[1],
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 处理普通URL图片
|
||||
parts.append(
|
||||
{
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
return converted_messages
|
||||
|
||||
def convert_gemini_response_to_openai(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Convert Gemini response to OpenAI format"""
|
||||
if stream:
|
||||
@@ -82,28 +94,17 @@ class ChatService:
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = self.format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = self.format_execution_result(
|
||||
text = format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = self.format_execution_result(
|
||||
text = format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
groundingChunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
text += f"**【引用来源】**\n\n"
|
||||
for _, groundingChunk in enumerate(groundingChunks, 1):
|
||||
if "web" in groundingChunk:
|
||||
text += self.create_search_link(groundingChunk["web"])
|
||||
text = self.add_search_link_text(model, candidate, text)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
@@ -131,37 +132,26 @@ class ChatService:
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response["candidates"][0]["content"]["parts"][0]["text"],
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
}
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response["candidates"][0]["content"]["parts"][0]["text"],
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
}
|
||||
try:
|
||||
if response.get("candidates"):
|
||||
text = response["candidates"][0]["content"]["parts"][0]["text"]
|
||||
candidate = response["candidates"][0]
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
groundingChunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
text += f"**【引用来源】**\n\n"
|
||||
for _, groundingChunk in enumerate(groundingChunks, 1):
|
||||
if "web" in groundingChunk:
|
||||
text += self.create_search_link(groundingChunk["web"])
|
||||
text = self.add_search_link_text(model, candidate, text)
|
||||
res["choices"][0]["message"]["content"] = text
|
||||
return res
|
||||
else:
|
||||
@@ -173,10 +163,27 @@ class ChatService:
|
||||
res["choices"][0]["message"]["content"] = f"Error converting Gemini response: {str(e)}"
|
||||
return res
|
||||
|
||||
def add_search_link_text(self, model, candidate, text):
    """Append a cited-sources section to ``text`` when grounding data exists.

    The section is added only when ``settings.SHOW_SEARCH_LINK`` is truthy,
    the model name ends with ``-search``, and the candidate carries
    ``groundingMetadata`` with a ``groundingChunks`` entry. Otherwise
    ``text`` is returned unchanged.

    Args:
        model: Requested model name.
        candidate: One response candidate mapping.
        text: Text generated so far.

    Returns:
        ``text``, optionally followed by a separator, a heading and one
        bullet link per grounding chunk that has a ``web`` entry.
    """
    if not (settings.SHOW_SEARCH_LINK and model.endswith("-search")):
        return text
    if "groundingMetadata" not in candidate:
        return text
    grounding_metadata = candidate["groundingMetadata"]
    if "groundingChunks" not in grounding_metadata:
        return text

    # Build all bullets in one pass instead of growing the string per chunk.
    links = "".join(
        create_search_link(grounding_chunk["web"])
        for grounding_chunk in grounding_metadata["groundingChunks"]
        if "web" in grounding_chunk
    )
    return text + "\n\n---\n\n" + "**【引用来源】**\n\n" + links
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""Create chat completion using either Gemini or OpenAI API"""
|
||||
model = request.model
|
||||
@@ -184,7 +191,7 @@ class ChatService:
|
||||
if tools is None:
|
||||
tools = []
|
||||
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
|
||||
model.endswith("-search") or "-thinking" in model
|
||||
model.endswith("-search") or "-thinking" in model
|
||||
):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
@@ -192,10 +199,10 @@ class ChatService:
|
||||
return await self._gemini_chat_completion(request, api_key, tools)
|
||||
|
||||
async def _gemini_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
tools: Optional[list] = None,
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
tools: Optional[list] = None,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""Handle Gemini API chat completion"""
|
||||
model = request.model
|
||||
@@ -210,7 +217,7 @@ class ChatService:
|
||||
gemini_model = model[:-7] # Remove -search suffix
|
||||
else:
|
||||
gemini_model = model
|
||||
gemini_messages = self.convert_messages_to_gemini_format(messages)
|
||||
gemini_messages = convert_messages_to_gemini_format(messages)
|
||||
|
||||
if not stream:
|
||||
# 非流式模式下,移除代码执行工具
|
||||
@@ -247,26 +254,26 @@ class ChatService:
|
||||
if stream:
|
||||
async def generate():
|
||||
retries = 0
|
||||
MAX_RETRIES = 3
|
||||
max_retries = 3
|
||||
current_api_key = api_key
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
while retries < max_retries:
|
||||
try:
|
||||
timeout = httpx.Timeout(
|
||||
60.0, read=60.0
|
||||
) # 连接超时60秒,读取超时60秒
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with httpx.AsyncClient(timeout=timeout) as async_client:
|
||||
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
|
||||
async with client.stream(
|
||||
"POST", stream_url, json=payload
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.read()
|
||||
async with async_client.stream(
|
||||
"POST", stream_url, json=payload
|
||||
) as async_response:
|
||||
if async_response.status_code != 200:
|
||||
error_content = await async_response.read()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
logger.error(
|
||||
f"API error: {response.status_code}, {error_msg}"
|
||||
f"API error: {async_response.status_code}, {error_msg}"
|
||||
)
|
||||
if retries < MAX_RETRIES - 1:
|
||||
if retries < max_retries - 1:
|
||||
current_api_key = (
|
||||
await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
@@ -276,12 +283,12 @@ class ChatService:
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"Max retries reached. Final error: {response.status_code}, {error_msg}"
|
||||
f"Max retries reached. Final error: {async_response.status_code}, {error_msg}"
|
||||
)
|
||||
yield f"data: {json.dumps({'error': f'API error: {response.status_code}, {error_msg}'})}\n\n"
|
||||
yield f"data: {json.dumps({'error': f'API error: {async_response.status_code}, {error_msg}'})}\n\n"
|
||||
return
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
async for line in async_response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
chunk = json.loads(line[6:])
|
||||
@@ -297,7 +304,7 @@ class ChatService:
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
yield f"data: {json.dumps(self.convert_gemini_response_to_openai({}, model,stream=True, finish_reason='stop'))}\n\n"
|
||||
yield f"data: {json.dumps(self.convert_gemini_response_to_openai({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
@@ -305,7 +312,7 @@ class ChatService:
|
||||
logger.warning(
|
||||
f"Read timeout occurred, attempting retry {retries + 1}"
|
||||
)
|
||||
if retries < MAX_RETRIES - 1:
|
||||
if retries < max_retries - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
@@ -322,7 +329,7 @@ class ChatService:
|
||||
logger.exception(
|
||||
f"Stream error: {str(e)}, attempting retry {retries + 1}"
|
||||
)
|
||||
if retries < MAX_RETRIES - 1:
|
||||
if retries < max_retries - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
@@ -362,14 +369,8 @@ class ChatService:
|
||||
|
||||
return f"""\n【代码执行】\n```{language}\n{code}\n```\n"""
|
||||
|
||||
def format_execution_result(self, result_data: dict) -> str:
|
||||
"""格式化执行结果输出"""
|
||||
outcome = result_data.get("outcome", "")
|
||||
output = result_data.get("output", "").strip()
|
||||
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n"""
|
||||
|
||||
async def generate_content(
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
) -> dict:
|
||||
"""调用Gemini API生成内容"""
|
||||
url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}"
|
||||
@@ -392,7 +393,7 @@ class ChatService:
|
||||
raise
|
||||
|
||||
async def stream_generate_content(
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator:
|
||||
"""调用Gemini API流式生成内容"""
|
||||
retries = 0
|
||||
@@ -406,7 +407,7 @@ class ChatService:
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream(
|
||||
"POST", url, json=request.model_dump()
|
||||
"POST", url, json=request.model_dump()
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_text = await response.text()
|
||||
@@ -451,6 +452,3 @@ class ChatService:
|
||||
retries += 1
|
||||
continue
|
||||
raise
|
||||
|
||||
def create_search_link(self, web):
|
||||
return f'\n- [{web["title"]}]({web["uri"]})'
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import Union, List
|
||||
|
||||
import openai
|
||||
from typing import Union, List, Dict, Any
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
|
||||
from app.core.logger import get_embeddings_logger
|
||||
|
||||
logger = get_embeddings_logger()
|
||||
@@ -11,7 +14,7 @@ class EmbeddingService:
|
||||
|
||||
async def create_embedding(
|
||||
self, input_text: Union[str, List[str]], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
) -> CreateEmbeddingResponse:
|
||||
"""Create embeddings using OpenAI API"""
|
||||
try:
|
||||
client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
|
||||
|
||||
Reference in New Issue
Block a user