mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 06:11:32 +08:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
929592bbc4 | ||
|
|
2225a40bbe | ||
|
|
3480fa3b0f | ||
|
|
d7113f5fc4 | ||
|
|
2072f54ca1 | ||
|
|
7c9b721164 | ||
|
|
83ce50975a | ||
|
|
7da9110704 | ||
|
|
e9d19de7c6 | ||
|
|
e822831178 | ||
|
|
775930edce |
@@ -23,6 +23,9 @@ CHECK_INTERVAL_HOURS=1
|
||||
TIMEZONE=Asia/Shanghai
|
||||
# 请求超时时间(秒)
|
||||
TIME_OUT=300
|
||||
# 代理服务器配置 (支持 http 和 socks5)
|
||||
# 示例: PROXIES=["http://user:pass@host:port", "socks5://host:port"]
|
||||
PROXIES=[]
|
||||
#########################image_generate 相关配置###########################
|
||||
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
|
||||
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
||||
@@ -44,3 +47,7 @@ STREAM_CHUNK_SIZE=5
|
||||
# 日志级别 (debug, info, warning, error, critical),默认为 info
|
||||
LOG_LEVEL=info
|
||||
##########################################################################
|
||||
|
||||
# 安全设置 (JSON 字符串格式)
|
||||
# 注意:这里的示例值可能需要根据实际模型支持情况调整
|
||||
SAFETY_SETTINGS='[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}]'
|
||||
|
||||
16
README.md
16
README.md
@@ -67,6 +67,7 @@ app/
|
||||
>镜像地址: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **模型列表自动维护**: 支持openai和gemini模型列表获取,与newapi自动获取模型列表完美兼容,无需手动填写。
|
||||
* **支持移除不使用的模型**: 默认提供的模型太多,很多用不上,可以通过`FILTERED_MODELS`过滤掉。
|
||||
* **代理支持**: 支持配置 HTTP/SOCKS5 代理服务器 (`PROXIES`),用于访问 Gemini API,方便在特殊网络环境下使用。支持批量添加代理。
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
@@ -166,6 +167,7 @@ app/
|
||||
| `CHECK_INTERVAL_HOURS` | 可选,检查禁用 Key 是否恢复的时间间隔 (小时) | `1` |
|
||||
| `TIMEZONE` | 可选,应用程序使用的时区 | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | 可选,请求超时时间 (秒) | `300` |
|
||||
| `PROXIES` | 可选,代理服务器列表 (例如 `http://user:pass@host:port`, `socks5://host:port`) | `[]` |
|
||||
| `LOG_LEVEL` | 可选,日志级别,例如 DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 可选,付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
@@ -193,12 +195,16 @@ app/
|
||||
* `POST /models/{model_name}:generateContent`: 使用指定的 Gemini 模型生成内容。
|
||||
* `POST /models/{model_name}:streamGenerateContent`: 使用指定的 Gemini 模型流式生成内容。
|
||||
|
||||
### OpenAI API 相关 (`(/hf)/v1`)
|
||||
### OpenAI API 相关
|
||||
|
||||
* `GET /v1/models`: 列出可用的 OpenAI 模型。
|
||||
* `POST /v1/chat/completions`: 通过 OpenAI API 进行聊天补全。
|
||||
* `POST /v1/images/generations`: 通过 OpenAI API 生成图像。
|
||||
* `POST /v1/embeddings`: 通过 OpenAI API 创建文本嵌入。
|
||||
* `GET (/hf)/v1/models`: 列出可用的模型 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/chat/completions`: 进行聊天补全 (底层用的gemini格式, 支持流式传输)。
|
||||
* `POST (/hf)/v1/embeddings`: 创建文本嵌入 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/images/generations`: 生成图像 (底层用的gemini格式)。
|
||||
* `GET /openai/v1/models`: 列出可用的模型 (底层用的openai格式)。
|
||||
* `POST /openai/v1/chat/completions`: 进行聊天补全 (底层用的openai格式, 支持流式传输, 可防止截断,速度也快)。
|
||||
* `POST /openai/v1/embeddings`: 创建文本嵌入 (底层用的openai格式)。
|
||||
* `POST /openai/v1/images/generations`: 生成图像 (底层用的openai格式)。
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import ValidationError
|
||||
from pydantic_settings import BaseSettings
|
||||
from sqlalchemy import insert, update, select
|
||||
|
||||
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT, MAX_RETRIES
|
||||
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_SAFETY_SETTINGS, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT, MAX_RETRIES
|
||||
from app.log.logger import Logger
|
||||
|
||||
|
||||
@@ -30,7 +30,8 @@ class Settings(BaseSettings):
|
||||
TEST_MODEL: str = DEFAULT_MODEL
|
||||
TIME_OUT: int = DEFAULT_TIMEOUT
|
||||
MAX_RETRIES: int = MAX_RETRIES
|
||||
|
||||
PROXIES: List[str] = [] # 新增:代理服务器列表
|
||||
|
||||
# 模型相关配置
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
@@ -68,6 +69,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = "INFO" # 默认日志级别
|
||||
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS # 新增:安全设置
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -120,6 +122,32 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
# Log other errors (ValueError, TypeError) or JSON errors without single quotes
|
||||
logger.error(f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict.")
|
||||
return parsed_dict # Return the parsed dict or an empty one if all attempts fail
|
||||
# 处理 List[Dict[str, str]]
|
||||
elif target_type == List[Dict[str, str]]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
# 验证列表中的每个元素是否为字典,并且键和值都是字符串
|
||||
valid = all(
|
||||
isinstance(item, dict) and
|
||||
all(isinstance(k, str) for k in item.keys()) and
|
||||
all(isinstance(v, str) for v in item.values())
|
||||
for item in parsed
|
||||
)
|
||||
if valid:
|
||||
return parsed
|
||||
else:
|
||||
logger.warning(f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}")
|
||||
return [] # 或者返回默认值?这里返回空列表
|
||||
else:
|
||||
logger.warning(f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}")
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list.")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list.")
|
||||
return []
|
||||
# 处理 bool
|
||||
elif target_type == bool:
|
||||
return db_value.lower() in ('true', '1', 'yes', 'on')
|
||||
|
||||
@@ -40,3 +40,40 @@ DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||
# 正则表达式模式
|
||||
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||
|
||||
# Audio/Video Settings
|
||||
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
||||
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
||||
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
|
||||
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
||||
AUDIO_FORMAT_TO_MIMETYPE = {
|
||||
"wav": "audio/wav",
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"ogg": "audio/ogg",
|
||||
}
|
||||
|
||||
VIDEO_FORMAT_TO_MIMETYPE = {
|
||||
"mp4": "video/mp4",
|
||||
"mov": "video/quicktime",
|
||||
"avi": "video/x-msvideo",
|
||||
"webm": "video/webm",
|
||||
}
|
||||
|
||||
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
|
||||
DEFAULT_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
||||
|
||||
@@ -9,11 +9,14 @@ class ChatRequest(BaseModel):
|
||||
model: str = DEFAULT_MODEL
|
||||
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[List[dict]] = []
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = DEFAULT_TOP_P
|
||||
top_k: Optional[int] = DEFAULT_TOP_K
|
||||
stop: Optional[List[str]] = []
|
||||
stop: Optional[Union[List[str],str]] = None
|
||||
reasoning_effort: Optional[str] = None
|
||||
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
||||
tool_choice: Optional[str] = None
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
@@ -23,10 +26,10 @@ class EmbeddingRequest(BaseModel):
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
model: str = "DALL-E-3"
|
||||
model: str = "imagen-3.0-generate-002"
|
||||
prompt: str = ""
|
||||
n: int = 1
|
||||
size: Optional[str] = "1024x1024"
|
||||
quality: Optional[str] = ""
|
||||
style: Optional[str] = ""
|
||||
response_format: Optional[str] = "url"
|
||||
quality: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
response_format: Optional[str] = "b64_json"
|
||||
|
||||
32
app/handler/error_handler.py
Normal file
32
app/handler/error_handler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import HTTPException
|
||||
import logging
|
||||
|
||||
@asynccontextmanager
|
||||
async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None):
|
||||
"""
|
||||
一个异步上下文管理器,用于统一处理 FastAPI 路由中的常见错误和日志记录。
|
||||
|
||||
Args:
|
||||
logger: 用于记录日志的 Logger 实例。
|
||||
operation_name: 操作的名称,用于日志记录和错误详情。
|
||||
success_message: 操作成功时记录的自定义消息 (可选)。
|
||||
failure_message: 操作失败时记录的自定义消息 (可选)。
|
||||
"""
|
||||
default_success_msg = f"{operation_name} request successful"
|
||||
default_failure_msg = f"{operation_name} request failed"
|
||||
|
||||
logger.info("-" * 50 + operation_name + "-" * 50)
|
||||
try:
|
||||
yield
|
||||
logger.info(success_message or default_success_msg)
|
||||
except HTTPException as http_exc:
|
||||
# 如果已经是 HTTPException,直接重新抛出,保留原始状态码和详情
|
||||
logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})")
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
# 对于其他所有异常,记录错误并抛出标准的 500 错误
|
||||
logger.error(f"{failure_message or default_failure_msg}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Internal server error during {operation_name}"
|
||||
) from e
|
||||
@@ -1,61 +1,70 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
import requests
|
||||
import base64
|
||||
|
||||
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
|
||||
import requests
|
||||
|
||||
from app.core.constants import (
|
||||
AUDIO_FORMAT_TO_MIMETYPE,
|
||||
DATA_URL_PATTERN,
|
||||
IMAGE_URL_PATTERN,
|
||||
MAX_AUDIO_SIZE_BYTES,
|
||||
MAX_VIDEO_SIZE_BYTES,
|
||||
SUPPORTED_AUDIO_FORMATS,
|
||||
SUPPORTED_ROLES,
|
||||
SUPPORTED_VIDEO_FORMATS,
|
||||
VIDEO_FORMAT_TO_MIMETYPE,
|
||||
)
|
||||
from app.log.logger import get_message_converter_logger
|
||||
|
||||
logger = get_message_converter_logger()
|
||||
|
||||
|
||||
class MessageConverter(ABC):
|
||||
"""消息转换器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
|
||||
def _get_mime_type_and_data(base64_string):
|
||||
"""
|
||||
从 base64 字符串中提取 MIME 类型和数据。
|
||||
|
||||
|
||||
参数:
|
||||
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
|
||||
|
||||
|
||||
返回:
|
||||
tuple: (mime_type, encoded_data)
|
||||
"""
|
||||
# 检查字符串是否以 "data:" 格式开始
|
||||
if base64_string.startswith('data:'):
|
||||
if base64_string.startswith("data:"):
|
||||
# 提取 MIME 类型和数据
|
||||
pattern = DATA_URL_PATTERN
|
||||
match = re.match(pattern, base64_string)
|
||||
if match:
|
||||
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
mime_type = (
|
||||
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
)
|
||||
encoded_data = match.group(2)
|
||||
return mime_type, encoded_data
|
||||
|
||||
|
||||
# 如果不是预期格式,假定它只是数据部分
|
||||
return None, base64_string
|
||||
|
||||
|
||||
def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||
if image_url.startswith("data:image"):
|
||||
mime_type, encoded_data = _get_mime_type_and_data(image_url)
|
||||
return {
|
||||
"inline_data": {
|
||||
"mime_type": mime_type,
|
||||
"data": encoded_data
|
||||
}
|
||||
}
|
||||
return {"inline_data": {"mime_type": mime_type, "data": encoded_data}}
|
||||
else:
|
||||
encoded_data = _convert_image_to_base64(image_url)
|
||||
return {
|
||||
"inline_data": {
|
||||
"mime_type": "image/png",
|
||||
"data": encoded_data
|
||||
}
|
||||
}
|
||||
return {"inline_data": {"mime_type": "image/png", "data": encoded_data}}
|
||||
|
||||
|
||||
def _convert_image_to_base64(url: str) -> str:
|
||||
@@ -69,7 +78,7 @@ def _convert_image_to_base64(url: str) -> str:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# 将图片内容转换为base64
|
||||
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||
img_data = base64.b64encode(response.content).decode("utf-8")
|
||||
return img_data
|
||||
else:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
@@ -93,12 +102,9 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
# 将URL对应的图片转换为base64
|
||||
try:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append({
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": base64_data
|
||||
}
|
||||
})
|
||||
parts.append(
|
||||
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
||||
)
|
||||
except Exception:
|
||||
# 如果转换失败,回退到文本模式
|
||||
parts.append({"text": text})
|
||||
@@ -111,42 +117,215 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
class OpenAIMessageConverter(MessageConverter):
|
||||
"""OpenAI消息格式转换器"""
|
||||
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
def _validate_media_data(
|
||||
self, format: str, data: str, supported_formats: List[str], max_size: int
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Validates format and size of Base64 media data."""
|
||||
if format.lower() not in supported_formats:
|
||||
logger.error(
|
||||
f"Unsupported media format: {format}. Supported: {supported_formats}"
|
||||
)
|
||||
raise ValueError(f"Unsupported media format: {format}")
|
||||
|
||||
try:
|
||||
# Decode Base64 to check size
|
||||
# Be careful with memory usage for very large files
|
||||
# Consider streaming decoding or checking length heuristic first if memory is a concern
|
||||
decoded_data = base64.b64decode(
|
||||
data, validate=True
|
||||
) # Use validate=True for stricter check
|
||||
if len(decoded_data) > max_size:
|
||||
logger.error(
|
||||
f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
|
||||
)
|
||||
raise ValueError(
|
||||
f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
|
||||
)
|
||||
# No need to return decoded_data, just the original base64 if valid
|
||||
return data
|
||||
except base64.binascii.Error as e:
|
||||
logger.error(f"Invalid Base64 data provided: {e}")
|
||||
raise ValueError("Invalid Base64 data")
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating media data: {e}")
|
||||
raise
|
||||
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
system_instruction_parts = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
|
||||
parts = []
|
||||
# 特别处理最后一个assistant的消息,按\n\n分割
|
||||
if "content" in msg and isinstance(msg["content"], str) and msg["content"] and role == "assistant" and idx == len(messages) - 2:
|
||||
# 按\n\n分割消息
|
||||
content_parts = msg["content"].split("\n\n")
|
||||
for part in content_parts:
|
||||
if not part.strip(): # 跳过空内容
|
||||
|
||||
if "content" in msg and isinstance(msg["content"], list):
|
||||
for content_item in msg["content"]:
|
||||
if not isinstance(content_item, dict):
|
||||
# Skip non-dict items if any unexpected format appears
|
||||
logger.warning(
|
||||
f"Skipping unexpected content item format: {type(content_item)}"
|
||||
)
|
||||
continue
|
||||
# 处理可能包含图片的文本
|
||||
parts.extend(_process_text_with_image(part))
|
||||
elif "content" in msg and isinstance(msg["content"], str) and msg["content"]:
|
||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
||||
|
||||
content_type = content_item.get("type")
|
||||
|
||||
if content_type == "text" and content_item.get("text"):
|
||||
parts.append({"text": content_item["text"]})
|
||||
elif content_type == "image_url" and content_item.get(
|
||||
"image_url", {}
|
||||
).get("url"):
|
||||
try:
|
||||
parts.append(
|
||||
_convert_image(content_item["image_url"]["url"])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
|
||||
)
|
||||
# Decide how to handle: skip part, add error text, etc.
|
||||
parts.append(
|
||||
{
|
||||
"text": f"[Error processing image: {content_item['image_url']['url']}]"
|
||||
}
|
||||
)
|
||||
# --- Add handling for input_audio ---
|
||||
elif content_type == "input_audio" and content_item.get(
|
||||
"input_audio"
|
||||
):
|
||||
audio_info = content_item["input_audio"]
|
||||
audio_data = audio_info.get("data")
|
||||
audio_format = audio_info.get("format", "").lower()
|
||||
|
||||
if not audio_data or not audio_format:
|
||||
logger.warning(
|
||||
"Skipping audio part due to missing data or format."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Validate size and format
|
||||
validated_data = self._validate_media_data(
|
||||
audio_format,
|
||||
audio_data,
|
||||
SUPPORTED_AUDIO_FORMATS,
|
||||
MAX_AUDIO_SIZE_BYTES,
|
||||
)
|
||||
|
||||
# Get MIME type
|
||||
mime_type = AUDIO_FORMAT_TO_MIMETYPE.get(audio_format)
|
||||
if not mime_type:
|
||||
# Should not happen if format validation passed, but double-check
|
||||
logger.error(
|
||||
f"Could not find MIME type for supported format: {audio_format}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Internal error: MIME type mapping missing for {audio_format}"
|
||||
)
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": mime_type,
|
||||
"data": validated_data, # Use the validated Base64 data
|
||||
}
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Successfully added audio part (format: {audio_format})"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Skipping audio part due to validation error: {e}"
|
||||
)
|
||||
parts.append({"text": f"[Error processing audio: {e}]"})
|
||||
except Exception:
|
||||
logger.exception("Unexpected error processing audio part.")
|
||||
parts.append(
|
||||
{"text": "[Unexpected error processing audio]"}
|
||||
)
|
||||
|
||||
elif content_type == "input_video" and content_item.get(
|
||||
"input_video"
|
||||
):
|
||||
video_info = content_item["input_video"]
|
||||
video_data = video_info.get("data")
|
||||
video_format = video_info.get("format", "").lower()
|
||||
|
||||
if not video_data or not video_format:
|
||||
logger.warning(
|
||||
"Skipping video part due to missing data or format."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
validated_data = self._validate_media_data(
|
||||
video_format,
|
||||
video_data,
|
||||
SUPPORTED_VIDEO_FORMATS,
|
||||
MAX_VIDEO_SIZE_BYTES,
|
||||
)
|
||||
mime_type = VIDEO_FORMAT_TO_MIMETYPE.get(video_format)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Internal error: MIME type mapping missing for {video_format}"
|
||||
)
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": mime_type,
|
||||
"data": validated_data,
|
||||
}
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Successfully added video part (format: {video_format})"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Skipping video part due to validation error: {e}"
|
||||
)
|
||||
parts.append({"text": f"[Error processing video: {e}]"})
|
||||
except Exception:
|
||||
logger.exception("Unexpected error processing video part.")
|
||||
parts.append(
|
||||
{"text": "[Unexpected error processing video]"}
|
||||
)
|
||||
|
||||
else:
|
||||
# Log unrecognized but present types
|
||||
if content_type:
|
||||
logger.warning(
|
||||
f"Unsupported content type or missing data in structured content: {content_type}"
|
||||
)
|
||||
|
||||
elif (
|
||||
"content" in msg and isinstance(msg["content"], str) and msg["content"]
|
||||
):
|
||||
parts.extend(_process_text_with_image(msg["content"]))
|
||||
elif "content" in msg and isinstance(msg["content"], list):
|
||||
for content in msg["content"]:
|
||||
if isinstance(content, str) and content:
|
||||
parts.append({"text": content})
|
||||
elif isinstance(content, dict):
|
||||
if content["type"] == "text" and content["text"]:
|
||||
parts.append({"text": content["text"]})
|
||||
elif content["type"] == "image_url":
|
||||
parts.append(_convert_image(content["image_url"]["url"]))
|
||||
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||
# Keep existing tool call processing
|
||||
for tool_call in msg["tool_calls"]:
|
||||
function_call = tool_call.get("function",{})
|
||||
function_call["args"] = json.loads(function_call.get("arguments","{}"))
|
||||
del function_call["arguments"]
|
||||
function_call = tool_call.get("function", {})
|
||||
# Sanitize arguments loading
|
||||
arguments_str = function_call.get("arguments", "{}")
|
||||
try:
|
||||
function_call["args"] = json.loads(arguments_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to decode tool call arguments: {arguments_str}"
|
||||
)
|
||||
function_call["args"] = {}
|
||||
if "arguments" in function_call:
|
||||
if "arguments" in function_call:
|
||||
del function_call["arguments"]
|
||||
|
||||
parts.append({"functionCall": function_call})
|
||||
|
||||
|
||||
if role not in SUPPORTED_ROLES:
|
||||
if role == "tool":
|
||||
role = "user"
|
||||
@@ -158,7 +337,14 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
role = "model"
|
||||
if parts:
|
||||
if role == "system":
|
||||
system_instruction_parts.extend(parts)
|
||||
text_only_parts = [p for p in parts if "text" in p]
|
||||
if len(text_only_parts) != len(parts):
|
||||
logger.warning(
|
||||
"Non-text parts found in system message; discarding them."
|
||||
)
|
||||
if text_only_parts:
|
||||
system_instruction_parts.extend(text_only_parts)
|
||||
|
||||
else:
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
@@ -170,4 +356,4 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
"parts": system_instruction_parts,
|
||||
}
|
||||
)
|
||||
return converted_messages, system_instruction
|
||||
return converted_messages, system_instruction
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
@@ -15,7 +15,9 @@ class ResponseHandler(ABC):
|
||||
"""响应处理器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -26,14 +28,20 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
self.thinking_first = True
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
if stream:
|
||||
return _handle_gemini_stream_response(response, model, stream)
|
||||
return _handle_gemini_normal_response(response, model, stream)
|
||||
|
||||
|
||||
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
|
||||
def _handle_openai_stream_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=True, gemini_format=False
|
||||
)
|
||||
if not text and not tool_calls:
|
||||
delta = {}
|
||||
else:
|
||||
@@ -50,8 +58,12 @@ def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
|
||||
def _handle_openai_normal_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=False, gemini_format=False
|
||||
)
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
@@ -60,7 +72,11 @@ def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
@@ -77,59 +93,67 @@ class OpenAIResponseHandler(ResponseHandler):
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if stream:
|
||||
return _handle_openai_stream_response(response, model, finish_reason)
|
||||
return _handle_openai_normal_response(response, model, finish_reason)
|
||||
|
||||
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
|
||||
|
||||
def handle_image_chat_response(
|
||||
self, image_str: str, model: str, stream=False, finish_reason="stop"
|
||||
):
|
||||
if stream:
|
||||
return _handle_openai_stream_image_response(image_str,model,finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str,model,finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
return _handle_openai_stream_image_response(image_str, model, finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str, model, finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(
|
||||
image_str: str, model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason
|
||||
}]
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
def _handle_openai_normal_image_response(
|
||||
image_str: str, model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": image_str
|
||||
},
|
||||
"finish_reason": finish_reason
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": image_str},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
|
||||
def _extract_result(
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
gemini_format: bool = False,
|
||||
) -> tuple[str, List[Dict[str, Any]]]:
|
||||
text, tool_calls = "", []
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
@@ -145,13 +169,9 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = _format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
text = _format_execution_result(parts[0]["executableCodeResult"])
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
text = _format_execution_result(parts[0]["codeExecutionResult"])
|
||||
elif "inlineData" in parts[0]:
|
||||
text = _extract_image_data(parts[0])
|
||||
else:
|
||||
@@ -165,10 +185,10 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = (
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
)
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
@@ -186,34 +206,47 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
|
||||
tool_calls = _extract_tool_calls(
|
||||
candidate["content"]["parts"], gemini_format
|
||||
)
|
||||
else:
|
||||
text = "暂无返回"
|
||||
return text, tool_calls
|
||||
|
||||
|
||||
def _extract_image_data(part: dict) -> str:
|
||||
image_uploader = None
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
base64_data = part["inlineData"]["data"]
|
||||
#将base64_data转成bytes数组
|
||||
# 将base64_data转成bytes数组
|
||||
bytes_data = base64.b64decode(base64_data)
|
||||
upload_response = image_uploader.upload(bytes_data,filename)
|
||||
upload_response = image_uploader.upload(bytes_data, filename)
|
||||
if upload_response.success:
|
||||
text = f"\n\n\n\n"
|
||||
else:
|
||||
text = ""
|
||||
return text
|
||||
|
||||
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
def _extract_tool_calls(
|
||||
parts: List[Dict[str, Any]], gemini_format: bool
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""提取工具调用信息"""
|
||||
if not parts or not isinstance(parts, list):
|
||||
return []
|
||||
@@ -249,8 +282,12 @@ def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> Lis
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
def _handle_gemini_stream_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
@@ -259,8 +296,12 @@ def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream:
|
||||
return response
|
||||
|
||||
|
||||
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
def _handle_gemini_normal_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
@@ -278,10 +319,10 @@ def _format_code_block(code_data: dict) -> str:
|
||||
|
||||
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
|
||||
@@ -206,3 +206,15 @@ def get_update_logger():
|
||||
|
||||
def get_scheduler_routes():
|
||||
return Logger.setup_logger("scheduler_routes")
|
||||
|
||||
|
||||
def get_message_converter_logger():
|
||||
return Logger.setup_logger("message_converter")
|
||||
|
||||
|
||||
def get_api_client_logger():
|
||||
return Logger.setup_logger("api_client")
|
||||
|
||||
|
||||
def get_openai_compatible_logger():
|
||||
return Logger.setup_logger("openai_compatible")
|
||||
@@ -30,6 +30,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
and not request.url.path.startswith(f"/{API_VERSION}")
|
||||
and not request.url.path.startswith("/health")
|
||||
and not request.url.path.startswith("/hf")
|
||||
and not request.url.path.startswith("/openai")
|
||||
and not request.url.path.startswith("/api/version/check")
|
||||
):
|
||||
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
import asyncio # 导入 asyncio
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest # 添加导入
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
|
||||
# 路由设置
|
||||
@@ -43,62 +44,57 @@ async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""获取可用的Gemini模型列表"""
|
||||
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
logger.info("-" * 50 + operation_name + "-" * 50)
|
||||
logger.info("Handling Gemini models list request")
|
||||
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
models_json = model_service.get_gemini_models(api_key)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
||||
|
||||
# 添加搜索模型
|
||||
if settings.SEARCH_MODELS:
|
||||
for name in settings.SEARCH_MODELS:
|
||||
model = model_mapping.get(name)
|
||||
|
||||
try:
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
models_data =await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
|
||||
models_json = deepcopy(models_data) # 操作副本以防修改原始缓存
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-search"
|
||||
display_name = f'{item.get("displayName")} For Search'
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
# 添加图像生成模型
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
model = model_mapping.get(name)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-image"
|
||||
display_name = f'{item.get("displayName")} For Image'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
# 添加思考模型的非思考版本
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
model = model_mapping.get(name)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-non-thinking"
|
||||
display_name = f'{item.get("displayName")} Non Thinking'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
return models_json
|
||||
|
||||
# 添加衍生模型
|
||||
if settings.SEARCH_MODELS:
|
||||
for name in settings.SEARCH_MODELS:
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
|
||||
logger.info("Gemini models list request successful")
|
||||
return models_json
|
||||
except HTTPException as http_exc:
|
||||
# 重新抛出已知的 HTTP 异常
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
|
||||
@@ -112,25 +108,22 @@ async def generate_content(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""非流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
@@ -144,25 +137,24 @@ async def stream_generate_content(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
# 流式请求的成功/失败日志在流处理中更复杂,这里仅用上下文管理器处理启动错误
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
# 注意:流本身的错误需要在服务层或流迭代中处理,这里只返回流响应
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Streaming request failed") from e
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
|
||||
121
app/router/openai_compatiable_routes.py
Normal file
121
app/router/openai_compatiable_routes.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.openai_models import (
|
||||
ChatRequest,
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService()
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_next_working_key_wrapper(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
|
||||
async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取OpenAI聊天服务实例"""
|
||||
return OpenAICompatiableService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/openai/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""获取可用模型列表。"""
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await openai_service.get_models(api_key)
|
||||
|
||||
|
||||
@router.post("/openai/v1/chat/completions")
|
||||
@RetryHandler(max_retries=settings.MAX_RETRIES, key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
# 检查是否为图像生成相关的聊天模型,如果是,则使用付费密钥
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key # 保存原始key(可能是普通key)
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key() # 获取付费密钥
|
||||
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}") # 使用 current_api_key
|
||||
|
||||
if is_image_chat:
|
||||
# 图像生成聊天,调用特定服务,不处理流式
|
||||
response = await openai_service.create_image_chat_completion(request, current_api_key)
|
||||
return response # 直接返回结果
|
||||
else:
|
||||
# 普通聊天补全
|
||||
response = await openai_service.create_chat_completion(request, current_api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
# 假设 openai_service.create_chat_completion 在流式时返回异步生成器
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
# 非流式直接返回结果
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/openai/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
# 强制使用配置的模型,确保请求中包含正确的模型信息
|
||||
request.model = settings.CREATE_IMAGE_MODEL
|
||||
return await openai_service.generate_images(request)
|
||||
|
||||
|
||||
@router.post("/openai/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理文本嵌入请求。"""
|
||||
operation_name = "embedding"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await openai_service.create_embeddings(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
@@ -9,6 +9,7 @@ from app.domain.openai_models import (
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors # 导入共享错误处理器
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
@@ -47,17 +48,13 @@ async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
logger.info("-" * 50 + "list_models" + "-" * 50)
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
return model_service.get_gemini_openai_models(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching models list"
|
||||
) from e
|
||||
"""获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。"""
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await model_service.get_gemini_openai_models(api_key)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@@ -70,33 +67,38 @@ async def chat_completion(
|
||||
key_manager: KeyManager = Depends(get_key_manager), # 保留 key_manager 用于获取 paid_key
|
||||
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
|
||||
):
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
api_key = await key_manager.get_paid_key()
|
||||
logger.info("-" * 50 + "chat_completion" + "-" * 50)
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
"""处理 OpenAI 聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
# 检查是否为图像生成相关的聊天模型
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key # 保存原始 key
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key() # 获取付费密钥
|
||||
|
||||
if not model_service.check_model_support(request.model):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
|
||||
try:
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
response = await chat_service.create_image_chat_completion(request, api_key)
|
||||
# 检查模型支持性应在错误处理块内,以便捕获并记录错误
|
||||
if not await model_service.check_model_support(request.model):
|
||||
# 使用 HTTPException,会被 handle_route_errors 捕获并记录
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
|
||||
if is_image_chat:
|
||||
# 图像生成聊天
|
||||
response = await chat_service.create_image_chat_completion(request, current_api_key)
|
||||
return response # 直接返回,不处理流式
|
||||
else:
|
||||
response = await chat_service.create_chat_completion(request, api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
logger.info("Chat completion request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
# 普通聊天补全
|
||||
response = await chat_service.create_chat_completion(request, current_api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
# 非流式直接返回结果
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@@ -105,18 +107,14 @@ async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
):
|
||||
logger.info("-" * 50 + "generate_image" + "-" * 50)
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
|
||||
try:
|
||||
"""处理 OpenAI 图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
# 注意:这里假设 image_create_service.generate_images 是同步函数
|
||||
# 如果它是异步的,需要 await
|
||||
response = image_create_service.generate_images(request)
|
||||
logger.info("Image generation request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation request failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Image generation request failed"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@@ -126,19 +124,16 @@ async def embedding(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
logger.info("-" * 50 + "embedding" + "-" * 50)
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
"""处理 OpenAI 文本嵌入请求。"""
|
||||
operation_name = "embedding"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
response = await embedding_service.create_embedding(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
logger.info("Embedding request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Embedding request failed") from e
|
||||
|
||||
|
||||
@router.get("/v1/keys/list")
|
||||
@@ -147,10 +142,10 @@ async def get_keys_list(
|
||||
_=Depends(security_service.verify_auth_token),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取有效和无效的API key列表"""
|
||||
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
|
||||
logger.info("Handling keys list request")
|
||||
try:
|
||||
"""获取有效和无效的API key列表 (需要管理 Token 认证)。"""
|
||||
operation_name = "get_keys_list"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling keys list request")
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -160,8 +155,3 @@ async def get_keys_list(
|
||||
},
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting keys list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching keys list"
|
||||
) from e
|
||||
|
||||
@@ -8,9 +8,9 @@ from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_routes_logger
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes # 新增导入 version_routes
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.stats_service import StatsService
|
||||
from app.service.stats.stats_service import StatsService
|
||||
|
||||
logger = get_routes_logger()
|
||||
|
||||
@@ -31,9 +31,10 @@ def setup_routers(app: FastAPI) -> None:
|
||||
app.include_router(gemini_routes.router_v1beta)
|
||||
app.include_router(config_routes.router)
|
||||
app.include_router(error_log_routes.router)
|
||||
app.include_router(scheduler_routes.router) # 新增包含 scheduler 路由
|
||||
app.include_router(stats_routes.router) # 包含 stats API 路由
|
||||
app.include_router(version_routes.router) # 包含 version API 路由
|
||||
app.include_router(scheduler_routes.router)
|
||||
app.include_router(stats_routes.router)
|
||||
app.include_router(version_routes.router)
|
||||
app.include_router(openai_compatiable_routes.router)
|
||||
|
||||
# 添加页面路由
|
||||
setup_page_routes(app)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from starlette import status
|
||||
from app.core.security import verify_auth_token
|
||||
from app.service.stats_service import StatsService
|
||||
from app.service.stats.stats_service import StatsService
|
||||
from app.log.logger import get_stats_logger
|
||||
|
||||
logger = get_stats_logger()
|
||||
|
||||
@@ -81,10 +81,10 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import datetime # Add datetime import
|
||||
import time # Add time import
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import (
|
||||
add_error_log,
|
||||
add_request_log,
|
||||
)
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.handler.message_converter import OpenAIMessageConverter
|
||||
from app.handler.response_handler import OpenAIResponseHandler
|
||||
@@ -16,17 +20,16 @@ from app.log.logger import get_openai_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log # Import add_request_log
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片部分"""
|
||||
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片、音频或视频部分 (inline_data)"""
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
if content and "parts" in content and isinstance(content["parts"], list):
|
||||
for part in content["parts"]:
|
||||
if "image_url" in part or "inline_data" in part:
|
||||
if isinstance(part, dict) and "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -46,9 +49,13 @@ def _build_tools(
|
||||
or model.endswith("-image")
|
||||
or model.endswith("-image-generation")
|
||||
)
|
||||
and not _has_image_parts(messages)
|
||||
and not _has_media_parts(messages) # Use the updated check
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
logger.debug("Code execution tool enabled.")
|
||||
elif _has_media_parts(messages):
|
||||
logger.debug("Code execution tool disabled due to media parts presence.")
|
||||
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
@@ -62,7 +69,9 @@ def _build_tools(
|
||||
if item.get("type", "") == "function" and item.get("function"):
|
||||
function = deepcopy(item.get("function"))
|
||||
parameters = function.get("parameters", {})
|
||||
if parameters.get("type") == "object" and not parameters.get("properties", {}):
|
||||
if parameters.get("type") == "object" and not parameters.get(
|
||||
"properties", {}
|
||||
):
|
||||
function.pop("parameters", None)
|
||||
|
||||
function_declarations.append(function)
|
||||
@@ -93,20 +102,8 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
# and "gemini-2.0-pro-exp" not in model
|
||||
# ):
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
return settings.GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _build_payload(
|
||||
@@ -131,9 +128,11 @@ def _build_payload(
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
if request.model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if request.model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model,1000)}
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
|
||||
}
|
||||
|
||||
if (
|
||||
instruction
|
||||
@@ -205,7 +204,7 @@ class OpenAIChatService:
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 on success
|
||||
status_code = 200
|
||||
return self.response_handler.handle_response(
|
||||
response, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
@@ -218,17 +217,17 @@ class OpenAIChatService:
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default if parsing fails
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key, # Note: Parameter name is gemini_key in add_error_log
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=payload,
|
||||
)
|
||||
raise e # Re-throw exception
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -238,7 +237,7 @@ class OpenAIChatService:
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _handle_stream_completion(
|
||||
@@ -261,6 +260,7 @@ class OpenAIChatService:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, current_attempt_key
|
||||
):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
chunk = json.loads(line[6:])
|
||||
openai_chunk = self.response_handler.handle_response(
|
||||
@@ -293,7 +293,7 @@ class OpenAIChatService:
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 on success
|
||||
status_code = 200
|
||||
break # 成功后退出循环
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
@@ -307,7 +307,7 @@ class OpenAIChatService:
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default if parsing fails
|
||||
status_code = 500
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
@@ -316,38 +316,40 @@ class OpenAIChatService:
|
||||
error_type="openai-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=payload,
|
||||
)
|
||||
|
||||
# Attempt to switch API Key
|
||||
# Ensure key_manager is available (might need adjustment if not always passed)
|
||||
if self.key_manager:
|
||||
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
|
||||
api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break # Exit loop if no key available
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break # Exit loop if key manager is missing
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming."
|
||||
)
|
||||
break # Exit loop after max retries
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
finally:
|
||||
# Log the final outcome of the streaming request
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key, # Log the last key used
|
||||
is_success=is_success, # Log the final success status
|
||||
status_code=status_code, # Log the last known status code
|
||||
latency_ms=latency_ms, # Log total time including retries
|
||||
request_time=request_datetime
|
||||
api_key=final_api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
# If the loop finished due to failure, yield error and DONE
|
||||
if not is_success and retries >= max_retries:
|
||||
@@ -355,9 +357,7 @@ class OpenAIChatService:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str
|
||||
self, request: ChatRequest, api_key: str
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
|
||||
image_generate_request = ImageGenerationRequest()
|
||||
@@ -367,18 +367,22 @@ class OpenAIChatService:
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_image_completion(request.model, image_res, api_key)
|
||||
return self._handle_stream_image_completion(
|
||||
request.model, image_res, api_key
|
||||
)
|
||||
else:
|
||||
return await self._handle_normal_image_completion(request.model, image_res, api_key)
|
||||
return await self._handle_normal_image_completion(
|
||||
request.model, image_res, api_key
|
||||
)
|
||||
|
||||
async def _handle_stream_image_completion(
|
||||
self, model: str, image_data: str, api_key:str
|
||||
self, model: str, image_data: str, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
logger.info(f"Starting stream image completion for model: {model}")
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now() # Although not used for DB log here
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None # Although not used for DB log here
|
||||
status_code = None
|
||||
|
||||
try:
|
||||
if image_data:
|
||||
@@ -402,7 +406,9 @@ class OpenAIChatService:
|
||||
# 如果没有文本内容(如图片URL等),整块输出
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
logger.info(f"Stream image completion finished successfully for model: {model}")
|
||||
logger.info(
|
||||
f"Stream image completion finished successfully for model: {model}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -410,46 +416,51 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
error_log_msg = f"Stream image completion failed for model {model}: {e}"
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500 # Default error code
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-stream", # Specific error type
|
||||
error_type="openai-image-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
|
||||
request_msg={
|
||||
"image_data_truncated": image_data[:1000]
|
||||
},
|
||||
)
|
||||
yield f"data: {json.dumps({'error': error_log_msg})}\n\n" # Send error to client
|
||||
yield "data: [DONE]\n\n" # Still need DONE message
|
||||
# Re-raising might break the stream, decide if needed
|
||||
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}")
|
||||
logger.info(
|
||||
f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _handle_normal_image_completion(
|
||||
self, model: str, image_data: str, api_key: str # Add api_key parameter
|
||||
self, model: str, image_data: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"Starting normal image completion for model: {model}")
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now() # Although not used for DB log here
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None # Although not used for DB log here
|
||||
status_code = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
result = self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
logger.info(f"Normal image completion finished successfully for model: {model}")
|
||||
logger.info(
|
||||
f"Normal image completion finished successfully for model: {model}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return result
|
||||
@@ -457,26 +468,30 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
error_log_msg = f"Normal image completion failed for model {model}: {e}"
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500 # Default error code
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-non-stream", # Specific error type
|
||||
error_type="openai-image-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
|
||||
request_msg={
|
||||
"image_data_truncated": image_data[:1000]
|
||||
},
|
||||
)
|
||||
# Re-raise the exception so the caller knows about the failure
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}")
|
||||
logger.info(
|
||||
f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
# app/services/chat/api_client.py
|
||||
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
from typing import Dict, Any, AsyncGenerator, Optional
|
||||
import httpx
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_api_client_logger
|
||||
from app.core.constants import DEFAULT_TIMEOUT
|
||||
|
||||
logger = get_api_client_logger()
|
||||
|
||||
class ApiClient(ABC):
|
||||
"""API客户端基类"""
|
||||
@@ -37,11 +40,41 @@ class GeminiApiClient(ApiClient):
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取可用的 Gemini 模型列表"""
|
||||
timeout = httpx.Timeout(timeout=5)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models?key={api_key}"
|
||||
try:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status() # 如果状态码不是 2xx,则引发 HTTPStatusError
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"获取模型列表失败: {e.response.status_code}")
|
||||
logger.error(e.response.text)
|
||||
# 返回 None 而不是抛出异常,以便上层处理
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求模型列表失败: {e}")
|
||||
# 返回 None 而不是抛出异常
|
||||
return None
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
if response.status_code != 200:
|
||||
@@ -53,7 +86,12 @@ class GeminiApiClient(ApiClient):
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
if response.status_code != 200:
|
||||
@@ -62,3 +100,96 @@ class GeminiApiClient(ApiClient):
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
"""OpenAI API客户端"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/openai/models"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/embeddings"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
payload = {
|
||||
"input": input,
|
||||
"model": model,
|
||||
}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/images/generations"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
@@ -1,50 +1,47 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_model_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
|
||||
logger = get_model_logger()
|
||||
|
||||
|
||||
class ModelService:
|
||||
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
url = f"{settings.BASE_URL}/models?key={api_key}"
|
||||
async def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""使用 GeminiApiClient 获取并过滤模型列表"""
|
||||
api_client = GeminiApiClient(base_url=settings.BASE_URL) # 实例化客户端
|
||||
gemini_models = await api_client.get_models(api_key)
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
gemini_models = response.json()
|
||||
|
||||
filtered_models_list = []
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
if model_id not in settings.FILTERED_MODELS:
|
||||
filtered_models_list.append(model)
|
||||
else:
|
||||
logger.debug(f"Filtered out model: {model_id}")
|
||||
|
||||
gemini_models["models"] = filtered_models_list
|
||||
return gemini_models
|
||||
else:
|
||||
logger.error(f"Error: {response.status_code}")
|
||||
logger.error(response.text)
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
if gemini_models is None:
|
||||
logger.error("从 API 客户端获取模型列表失败。")
|
||||
return None
|
||||
|
||||
def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
gemini_models = self.get_gemini_models(api_key)
|
||||
return self.convert_to_openai_models_format(gemini_models)
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
filtered_models_list = []
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
if model_id not in settings.FILTERED_MODELS:
|
||||
filtered_models_list.append(model)
|
||||
else:
|
||||
logger.debug(f"Filtered out model: {model_id}")
|
||||
|
||||
gemini_models["models"] = filtered_models_list
|
||||
return gemini_models
|
||||
except Exception as e:
|
||||
logger.error(f"处理模型列表时出错: {e}")
|
||||
return None
|
||||
|
||||
def convert_to_openai_models_format(
|
||||
async def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取 Gemini 模型并转换为 OpenAI 格式"""
|
||||
gemini_models = await self.get_gemini_models(api_key)
|
||||
if gemini_models is None:
|
||||
return None
|
||||
|
||||
return await self.convert_to_openai_models_format(gemini_models)
|
||||
|
||||
async def convert_to_openai_models_format(
|
||||
self, gemini_models: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
openai_format = {"object": "list", "data": [], "success": True}
|
||||
@@ -81,7 +78,7 @@ class ModelService:
|
||||
openai_format["data"].append(image_model)
|
||||
return openai_format
|
||||
|
||||
def check_model_support(self, model: str) -> bool:
|
||||
async def check_model_support(self, model: str) -> bool:
|
||||
if not model or not isinstance(model, str):
|
||||
return False
|
||||
|
||||
|
||||
197
app/service/openai_compatiable/openai_compatiable_service.py
Normal file
197
app/service/openai_compatiable/openai_compatiable_service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import (
|
||||
add_error_log,
|
||||
add_request_log,
|
||||
)
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.service.client.api_client import OpenaiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
class OpenAICompatiableService:
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
self.key_manager = key_manager
|
||||
self.base_url = base_url
|
||||
self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
return await self.api_client.get_models(api_key)
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
request_dict = request.model_dump()
|
||||
# 移除值为null的
|
||||
request_dict = {k: v for k, v in request_dict.items() if v is not None}
|
||||
del request_dict["top_k"] # 删除top_k参数,目前不支持该参数
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, request_dict, api_key)
|
||||
return await self._handle_normal_completion(request.model, request_dict, api_key)
|
||||
|
||||
async def generate_images(
|
||||
self,
|
||||
request: ImageGenerationRequest,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成图片"""
|
||||
request_dict = request.model_dump()
|
||||
# 移除值为null的
|
||||
request_dict = {k: v for k, v in request_dict.items() if v is not None}
|
||||
api_key = settings.PAID_KEY
|
||||
return await self.api_client.generate_images(request_dict, api_key)
|
||||
|
||||
async def create_embeddings(
|
||||
self,
|
||||
input_text: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建嵌入"""
|
||||
return await self.api_client.create_embeddings(input_text, model, api_key)
|
||||
|
||||
async def _handle_normal_completion(
|
||||
self, model: str, request: dict, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理普通聊天完成"""
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
try:
|
||||
response = await self.api_client.generate_content(request, api_key)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
# Try to parse status code from exception
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-compatiable-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: dict, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑"""
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
is_success = False
|
||||
status_code = None
|
||||
final_api_key = api_key
|
||||
|
||||
while retries < max_retries:
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, current_attempt_key
|
||||
):
|
||||
if line.startswith("data:"):
|
||||
# print(line)
|
||||
yield line + "\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
is_success = True
|
||||
status_code = 200
|
||||
break # 成功后退出循环
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
# Parse error code for logging
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="openai-compatiable-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
)
|
||||
|
||||
# Attempt to switch API Key
|
||||
# Ensure key_manager is available (might need adjustment if not always passed)
|
||||
if self.key_manager:
|
||||
api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
finally:
|
||||
# Log the final outcome of the streaming request
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
# If the loop finished due to failure, yield error and DONE
|
||||
if not is_success and retries >= max_retries:
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
// 将需要在外部函数访问的 DOM 元素移到外部
|
||||
const safetySettingsContainer = document.getElementById('SAFETY_SETTINGS_container');
|
||||
const thinkingModelsContainer = document.getElementById('THINKING_MODELS_container');
|
||||
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// 初始化配置
|
||||
initConfig();
|
||||
@@ -63,13 +67,29 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
const cancelBulkDeleteApiKeyBtn = document.getElementById('cancelBulkDeleteApiKeyBtn'); // 新增
|
||||
const confirmBulkDeleteApiKeyBtn = document.getElementById('confirmBulkDeleteApiKeyBtn'); // 新增
|
||||
const bulkDeleteApiKeyInput = document.getElementById('bulkDeleteApiKeyInput'); // 新增
|
||||
|
||||
|
||||
// --- 新增:Proxy 模态框相关 ---
|
||||
const proxyModal = document.getElementById('proxyModal');
|
||||
const addProxyBtn = document.getElementById('addProxyBtn'); // Changed from bulkAddProxyBtn
|
||||
const closeProxyModalBtn = document.getElementById('closeProxyModalBtn');
|
||||
const cancelAddProxyBtn = document.getElementById('cancelAddProxyBtn');
|
||||
const confirmAddProxyBtn = document.getElementById('confirmAddProxyBtn');
|
||||
const proxyBulkInput = document.getElementById('proxyBulkInput');
|
||||
const bulkDeleteProxyBtn = document.getElementById('bulkDeleteProxyBtn'); // 新增
|
||||
const bulkDeleteProxyModal = document.getElementById('bulkDeleteProxyModal'); // 新增
|
||||
const closeBulkDeleteProxyModalBtn = document.getElementById('closeBulkDeleteProxyModalBtn'); // 新增
|
||||
const cancelBulkDeleteProxyBtn = document.getElementById('cancelBulkDeleteProxyBtn'); // 新增
|
||||
const confirmBulkDeleteProxyBtn = document.getElementById('confirmBulkDeleteProxyBtn'); // 新增
|
||||
const bulkDeleteProxyInput = document.getElementById('bulkDeleteProxyInput'); // 新增
|
||||
// --- 结束:Proxy 模态框相关 ---
|
||||
|
||||
// --- 新增:重置确认模态框相关 ---
|
||||
const resetConfirmModal = document.getElementById('resetConfirmModal');
|
||||
const closeResetModalBtn = document.getElementById('closeResetModalBtn');
|
||||
const cancelResetBtn = document.getElementById('cancelResetBtn');
|
||||
const confirmResetBtn = document.getElementById('confirmResetBtn');
|
||||
// --- 结束:新增 ---
|
||||
// const safetySettingsContainer = document.getElementById('SAFETY_SETTINGS_container'); // Moved outside
|
||||
|
||||
|
||||
// 打开模态框
|
||||
@@ -111,8 +131,14 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
if (event.target == bulkDeleteApiKeyModal) { // 新增对批量删除模态框的处理
|
||||
bulkDeleteApiKeyModal.classList.remove('show');
|
||||
}
|
||||
if (event.target == proxyModal) { // 新增对代理模态框的处理
|
||||
proxyModal.classList.remove('show');
|
||||
}
|
||||
if (event.target == bulkDeleteProxyModal) { // 新增对批量删除代理模态框的处理
|
||||
bulkDeleteProxyModal.classList.remove('show');
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
// 确认添加 API Key
|
||||
if (confirmAddApiKeyBtn) {
|
||||
confirmAddApiKeyBtn.addEventListener('click', handleBulkAddApiKeys);
|
||||
@@ -158,7 +184,77 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
}
|
||||
// --- 结束:批量删除 API Key 相关 ---
|
||||
// --- 结束:API Key 相关 ---
|
||||
|
||||
// --- 新增:Proxy 模态框事件 ---
|
||||
// 打开模态框 (Changed event listener to addProxyBtn)
|
||||
if (addProxyBtn) {
|
||||
addProxyBtn.addEventListener('click', () => {
|
||||
if (proxyModal) {
|
||||
proxyModal.classList.add('show');
|
||||
}
|
||||
if (proxyBulkInput) proxyBulkInput.value = ''; // 清空输入框
|
||||
});
|
||||
}
|
||||
|
||||
// 关闭模态框 (X 按钮)
|
||||
if (closeProxyModalBtn) {
|
||||
closeProxyModalBtn.addEventListener('click', () => {
|
||||
if (proxyModal) {
|
||||
proxyModal.classList.remove('show');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 关闭模态框 (取消按钮)
|
||||
if (cancelAddProxyBtn) {
|
||||
cancelAddProxyBtn.addEventListener('click', () => {
|
||||
if (proxyModal) {
|
||||
proxyModal.classList.remove('show');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 确认添加 Proxy
|
||||
if (confirmAddProxyBtn) {
|
||||
confirmAddProxyBtn.addEventListener('click', handleBulkAddProxies);
|
||||
}
|
||||
// --- 结束:Proxy 模态框事件 ---
|
||||
|
||||
// --- 新增:批量删除 Proxy 相关事件 ---
|
||||
// 打开批量删除模态框
|
||||
if (bulkDeleteProxyBtn) {
|
||||
bulkDeleteProxyBtn.addEventListener('click', () => {
|
||||
if (bulkDeleteProxyModal) {
|
||||
bulkDeleteProxyModal.classList.add('show');
|
||||
}
|
||||
if (bulkDeleteProxyInput) bulkDeleteProxyInput.value = ''; // 清空输入框
|
||||
});
|
||||
}
|
||||
|
||||
// 关闭批量删除模态框 (X 按钮)
|
||||
if (closeBulkDeleteProxyModalBtn) {
|
||||
closeBulkDeleteProxyModalBtn.addEventListener('click', () => {
|
||||
if (bulkDeleteProxyModal) {
|
||||
bulkDeleteProxyModal.classList.remove('show');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 关闭批量删除模态框 (取消按钮)
|
||||
if (cancelBulkDeleteProxyBtn) {
|
||||
cancelBulkDeleteProxyBtn.addEventListener('click', () => {
|
||||
if (bulkDeleteProxyModal) {
|
||||
bulkDeleteProxyModal.classList.remove('show');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 确认批量删除 Proxy
|
||||
if (confirmBulkDeleteProxyBtn) {
|
||||
confirmBulkDeleteProxyBtn.addEventListener('click', handleBulkDeleteProxies);
|
||||
}
|
||||
// --- 结束:批量删除 Proxy 相关 ---
|
||||
|
||||
// --- 新增:重置确认模态框事件监听 (移到 DOMContentLoaded 内部) ---
|
||||
if (closeResetModalBtn) {
|
||||
closeResetModalBtn.addEventListener('click', () => {
|
||||
@@ -206,7 +302,7 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
// --- 结束:思考模型预算映射相关 ---
|
||||
|
||||
// 添加事件委托,处理动态添加的 THINKING_MODELS 输入框的 input 事件
|
||||
const thinkingModelsContainer = document.getElementById('THINKING_MODELS_container');
|
||||
// const thinkingModelsContainer = document.getElementById('THINKING_MODELS_container'); // Moved outside
|
||||
if (thinkingModelsContainer) {
|
||||
thinkingModelsContainer.addEventListener('input', function(event) {
|
||||
if (event.target && event.target.classList.contains('array-input') && event.target.closest('.array-item[data-model-id]')) {
|
||||
@@ -220,6 +316,12 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
});
|
||||
}
|
||||
|
||||
// --- 新增:安全设置相关 ---
|
||||
const addSafetySettingBtn = document.getElementById('addSafetySettingBtn');
|
||||
if (addSafetySettingBtn) {
|
||||
addSafetySettingBtn.addEventListener('click', () => addSafetySettingItem());
|
||||
}
|
||||
// --- 结束:安全设置相关 ---
|
||||
|
||||
}); // <-- DOMContentLoaded 结束括号
|
||||
|
||||
@@ -265,6 +367,10 @@ async function initConfig() {
|
||||
if (!config.FILTERED_MODELS || !Array.isArray(config.FILTERED_MODELS) || config.FILTERED_MODELS.length === 0) {
|
||||
config.FILTERED_MODELS = ['gemini-1.0-pro-latest'];
|
||||
}
|
||||
// --- 新增:处理 PROXIES 默认值 ---
|
||||
if (!config.PROXIES || !Array.isArray(config.PROXIES)) {
|
||||
config.PROXIES = []; // 默认为空数组
|
||||
}
|
||||
// --- 新增:处理新字段的默认值 ---
|
||||
if (!config.THINKING_MODELS || !Array.isArray(config.THINKING_MODELS)) {
|
||||
config.THINKING_MODELS = []; // 默认为空数组
|
||||
@@ -272,7 +378,11 @@ async function initConfig() {
|
||||
if (!config.THINKING_BUDGET_MAP || typeof config.THINKING_BUDGET_MAP !== 'object' || config.THINKING_BUDGET_MAP === null) {
|
||||
config.THINKING_BUDGET_MAP = {}; // 默认为空对象
|
||||
}
|
||||
// --- 结束:处理新字段的默认值 ---
|
||||
// --- 新增:处理 SAFETY_SETTINGS 默认值 ---
|
||||
if (!config.SAFETY_SETTINGS || !Array.isArray(config.SAFETY_SETTINGS)) {
|
||||
config.SAFETY_SETTINGS = []; // 默认为空数组
|
||||
}
|
||||
// --- 结束:处理 SAFETY_SETTINGS 默认值 ---
|
||||
|
||||
populateForm(config);
|
||||
|
||||
@@ -296,6 +406,7 @@ async function initConfig() {
|
||||
SEARCH_MODELS: ['gemini-1.5-flash-latest'],
|
||||
FILTERED_MODELS: ['gemini-1.0-pro-latest'],
|
||||
UPLOAD_PROVIDER: 'smms',
|
||||
PROXIES: [], // 添加默认值
|
||||
THINKING_MODELS: [],
|
||||
THINKING_BUDGET_MAP: {}
|
||||
};
|
||||
@@ -410,6 +521,24 @@ function populateForm(config) {
|
||||
if (uploadProvider) {
|
||||
toggleProviderConfig(uploadProvider.value);
|
||||
}
|
||||
|
||||
// --- 新增:填充 SAFETY_SETTINGS ---
|
||||
let safetyItemsAdded = false;
|
||||
if (safetySettingsContainer && Array.isArray(config.SAFETY_SETTINGS)) {
|
||||
config.SAFETY_SETTINGS.forEach(setting => {
|
||||
if (setting && typeof setting === 'object' && setting.category && setting.threshold) {
|
||||
addSafetySettingItem(setting.category, setting.threshold);
|
||||
safetyItemsAdded = true;
|
||||
} else {
|
||||
console.warn("Invalid safety setting item found:", setting);
|
||||
}
|
||||
});
|
||||
}
|
||||
// 如果没有添加任何安全设置项,则显示占位符
|
||||
if (safetySettingsContainer && !safetyItemsAdded) {
|
||||
safetySettingsContainer.innerHTML = '<div class="text-gray-500 text-sm italic">定义模型的安全过滤阈值。</div>';
|
||||
}
|
||||
// --- 结束:填充 SAFETY_SETTINGS ---
|
||||
}
|
||||
|
||||
// --- 新增:处理批量添加 API Key 的逻辑 ---
|
||||
@@ -521,6 +650,92 @@ function handleBulkDeleteApiKeys() {
|
||||
bulkDeleteTextarea.value = '';
|
||||
}
|
||||
|
||||
// --- 新增:处理批量添加 Proxy 的逻辑 ---
|
||||
function handleBulkAddProxies() {
|
||||
const proxyBulkInput = document.getElementById('proxyBulkInput');
|
||||
const proxyContainer = document.getElementById('PROXIES_container');
|
||||
const proxyModal = document.getElementById('proxyModal');
|
||||
|
||||
if (!proxyBulkInput || !proxyContainer || !proxyModal) return;
|
||||
|
||||
const bulkText = proxyBulkInput.value;
|
||||
// 匹配 http(s):// 或 socks5:// 格式的代理,允许包含用户名密码
|
||||
const proxyRegex = /(?:https?|socks5):\/\/(?:[^:@\/]+(?::[^@\/]+)?@)?(?:[^:\/\s]+)(?::\d+)?/g;
|
||||
const extractedProxies = bulkText.match(proxyRegex) || [];
|
||||
|
||||
// 获取当前已有的 proxies
|
||||
const currentProxyInputs = proxyContainer.querySelectorAll('.array-input');
|
||||
const currentProxies = Array.from(currentProxyInputs).map(input => input.value).filter(proxy => proxy.trim() !== '');
|
||||
|
||||
// 合并并去重
|
||||
const combinedProxies = new Set([...currentProxies, ...extractedProxies]);
|
||||
const uniqueProxies = Array.from(combinedProxies);
|
||||
|
||||
// 清空现有列表显示
|
||||
const existingItems = proxyContainer.querySelectorAll('.array-item');
|
||||
existingItems.forEach(item => item.remove());
|
||||
|
||||
// 重新填充列表
|
||||
uniqueProxies.forEach(proxy => {
|
||||
addArrayItemWithValue('PROXIES', proxy);
|
||||
});
|
||||
|
||||
// 关闭模态框
|
||||
proxyModal.classList.remove('show');
|
||||
showNotification(`添加/更新了 ${uniqueProxies.length} 个唯一代理`, 'success');
|
||||
}
|
||||
// --- 结束:处理批量添加 Proxy 的逻辑 ---
|
||||
|
||||
// --- 新增:处理批量删除 Proxy 的逻辑 ---
|
||||
function handleBulkDeleteProxies() {
|
||||
const bulkDeleteTextarea = document.getElementById('bulkDeleteProxyInput');
|
||||
const proxyContainer = document.getElementById('PROXIES_container');
|
||||
const bulkDeleteModal = document.getElementById('bulkDeleteProxyModal');
|
||||
|
||||
if (!bulkDeleteTextarea || !proxyContainer || !bulkDeleteModal) return;
|
||||
|
||||
const bulkText = bulkDeleteTextarea.value;
|
||||
if (!bulkText.trim()) {
|
||||
showNotification('请粘贴需要删除的代理地址', 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
// 使用与添加时相同的正则表达式来提取要删除的代理
|
||||
const proxyRegex = /(?:https?|socks5):\/\/(?:[^:@\/]+(?::[^@\/]+)?@)?(?:[^:\/\s]+)(?::\d+)?/g;
|
||||
const proxiesToDelete = new Set(bulkText.match(proxyRegex) || []); // 使用 Set 进行高效查找
|
||||
|
||||
if (proxiesToDelete.size === 0) {
|
||||
showNotification('未在输入内容中提取到有效的代理地址格式', 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
const proxyItems = proxyContainer.querySelectorAll('.array-item');
|
||||
let deleteCount = 0;
|
||||
|
||||
proxyItems.forEach(item => {
|
||||
const input = item.querySelector('.array-input');
|
||||
// 检查输入框是否存在及其值是否在要删除的集合中
|
||||
if (input && proxiesToDelete.has(input.value)) {
|
||||
item.remove(); // 删除整个数组项元素
|
||||
deleteCount++;
|
||||
}
|
||||
});
|
||||
|
||||
// 关闭模态框
|
||||
bulkDeleteModal.classList.remove('show');
|
||||
|
||||
// 提供反馈
|
||||
if (deleteCount > 0) {
|
||||
showNotification(`成功删除了 ${deleteCount} 个匹配的代理`, 'success');
|
||||
} else {
|
||||
showNotification('列表中未找到您输入的任何代理进行删除', 'info');
|
||||
}
|
||||
|
||||
// 处理后清空文本区域
|
||||
bulkDeleteTextarea.value = '';
|
||||
}
|
||||
// --- 结束:处理批量删除 Proxy 的逻辑 ---
|
||||
|
||||
// 切换标签
|
||||
function switchTab(tabId) {
|
||||
// 更新标签按钮状态
|
||||
@@ -781,6 +996,23 @@ function collectFormData() {
|
||||
}
|
||||
// --- 结束:处理 THINKING_BUDGET_MAP ---
|
||||
|
||||
// --- 新增:处理 SAFETY_SETTINGS ---
|
||||
if (safetySettingsContainer) {
|
||||
formData['SAFETY_SETTINGS'] = [];
|
||||
const settingItems = safetySettingsContainer.querySelectorAll('.safety-setting-item');
|
||||
settingItems.forEach(item => {
|
||||
const categorySelect = item.querySelector('.safety-category-select');
|
||||
const thresholdSelect = item.querySelector('.safety-threshold-select');
|
||||
if (categorySelect && thresholdSelect && categorySelect.value && thresholdSelect.value) {
|
||||
formData['SAFETY_SETTINGS'].push({
|
||||
category: categorySelect.value,
|
||||
threshold: thresholdSelect.value
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
// --- 结束:处理 SAFETY_SETTINGS ---
|
||||
|
||||
return formData;
|
||||
}
|
||||
|
||||
@@ -975,10 +1207,6 @@ function generateRandomToken() {
|
||||
}
|
||||
// --- 结束:生成随机令牌函数 ---
|
||||
|
||||
// --- 修改:添加思考模型预算映射项 (现在由添加思考模型触发) ---
|
||||
// function addBudgetMapItem() {
|
||||
// // 不再需要手动添加
|
||||
// }
|
||||
|
||||
// Deprecated: This function is now effectively replaced by createAndAppendBudgetMapItem
|
||||
// for the initial population logic. It delegates to the new function if called.
|
||||
@@ -988,3 +1216,86 @@ function addBudgetMapItemWithValue(mapKey, mapValue, modelId) {
|
||||
createAndAppendBudgetMapItem(mapKey, mapValue, modelId);
|
||||
}
|
||||
/* --- 结束:(addBudgetMapItemWithValue 已弃用) --- */
|
||||
|
||||
|
||||
// --- 新增:添加安全设置项的函数 ---
|
||||
function addSafetySettingItem(category = '', threshold = '') {
|
||||
const container = document.getElementById('SAFETY_SETTINGS_container');
|
||||
if (!container) {
|
||||
console.error("Cannot add safety setting: SAFETY_SETTINGS_container not found!");
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果容器当前只有占位符,则清除它
|
||||
const placeholder = container.querySelector('.text-gray-500.italic');
|
||||
if (placeholder && container.children.length === 1 && container.firstChild === placeholder) {
|
||||
container.innerHTML = '';
|
||||
}
|
||||
|
||||
const harmCategories = [
|
||||
"HARM_CATEGORY_HARASSMENT",
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY" // 根据需要添加或移除
|
||||
];
|
||||
const harmThresholds = [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"OFF" // 根据 Google API 文档添加或移除
|
||||
];
|
||||
|
||||
const settingItem = document.createElement('div');
|
||||
settingItem.className = 'safety-setting-item flex items-center mb-2 gap-2';
|
||||
|
||||
// Category Select
|
||||
const categorySelect = document.createElement('select');
|
||||
categorySelect.className = 'safety-category-select flex-grow px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 bg-white';
|
||||
harmCategories.forEach(cat => {
|
||||
const option = document.createElement('option');
|
||||
option.value = cat;
|
||||
option.textContent = cat.replace('HARM_CATEGORY_', ''); // 显示更友好的名称
|
||||
if (cat === category) {
|
||||
option.selected = true;
|
||||
}
|
||||
categorySelect.appendChild(option);
|
||||
});
|
||||
|
||||
// Threshold Select
|
||||
const thresholdSelect = document.createElement('select');
|
||||
thresholdSelect.className = 'safety-threshold-select w-48 px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 bg-white';
|
||||
harmThresholds.forEach(thr => {
|
||||
const option = document.createElement('option');
|
||||
option.value = thr;
|
||||
option.textContent = thr.replace('BLOCK_', '').replace('_AND_ABOVE', '+'); // 简化显示
|
||||
if (thr === threshold) {
|
||||
option.selected = true;
|
||||
}
|
||||
thresholdSelect.appendChild(option);
|
||||
});
|
||||
|
||||
// Remove Button
|
||||
const removeBtn = document.createElement('button');
|
||||
removeBtn.type = 'button';
|
||||
removeBtn.className = 'remove-btn text-gray-400 hover:text-red-500 focus:outline-none transition-colors duration-150';
|
||||
removeBtn.innerHTML = '<i class="fas fa-trash-alt"></i>';
|
||||
removeBtn.title = '删除此设置';
|
||||
removeBtn.addEventListener('click', function() {
|
||||
const currentItem = this.closest('.safety-setting-item');
|
||||
currentItem.remove();
|
||||
// 检查容器是否为空,如果是,则添加回占位符
|
||||
if (container.children.length === 0) {
|
||||
container.innerHTML = '<div class="text-gray-500 text-sm italic">定义模型的安全过滤阈值。</div>';
|
||||
}
|
||||
});
|
||||
|
||||
settingItem.appendChild(categorySelect);
|
||||
settingItem.appendChild(thresholdSelect);
|
||||
settingItem.appendChild(removeBtn);
|
||||
|
||||
container.appendChild(settingItem);
|
||||
}
|
||||
// --- 结束:添加安全设置项的函数 ---
|
||||
|
||||
|
||||
@@ -17,13 +17,27 @@ self.addEventListener('install', event => {
|
||||
|
||||
self.addEventListener('fetch', event => {
|
||||
event.respondWith(
|
||||
caches.match(event.request)
|
||||
.then(response => {
|
||||
if (response) {
|
||||
return response;
|
||||
}
|
||||
return fetch(event.request);
|
||||
})
|
||||
caches.open(CACHE_NAME).then(cache => {
|
||||
// 1. 尝试从缓存获取
|
||||
return cache.match(event.request).then(responseFromCache => {
|
||||
// 2. 同时从网络获取 (后台进行)
|
||||
const fetchPromise = fetch(event.request).then(responseFromNetwork => {
|
||||
// 3. 网络请求成功,更新缓存
|
||||
cache.put(event.request, responseFromNetwork.clone());
|
||||
return responseFromNetwork;
|
||||
}).catch(err => {
|
||||
// 网络请求失败时,可以选择记录错误或不执行任何操作
|
||||
console.error('Network fetch failed:', err);
|
||||
// 确保即使网络失败,如果缓存存在,我们仍然返回缓存
|
||||
// 如果缓存也不存在,则此 Promise 会 reject
|
||||
throw err;
|
||||
});
|
||||
|
||||
// 4. 如果缓存存在,立即返回缓存;否则等待网络响应
|
||||
// 后台的网络请求仍在进行,用于更新缓存
|
||||
return responseFromCache || fetchPromise;
|
||||
});
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -182,6 +182,22 @@
|
||||
<input type="number" id="MAX_RETRIES" name="MAX_RETRIES" min="0" max="10" class="w-full px-4 py-3 rounded-lg border border-gray-300 focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50">
|
||||
<small class="text-gray-500 mt-1 block">API请求失败后的最大重试次数</small>
|
||||
</div>
|
||||
<!-- 代理服务器列表 -->
|
||||
<div class="mb-6">
|
||||
<label for="PROXIES" class="block font-semibold mb-2 text-gray-700">代理服务器列表</label>
|
||||
<div class="array-container bg-white rounded-lg border border-gray-200 p-4 mb-2" id="PROXIES_container">
|
||||
<!-- 代理项将在这里动态添加 -->
|
||||
</div>
|
||||
<div class="flex justify-end gap-2">
|
||||
<button type="button" class="bg-danger-600 hover:bg-danger-700 text-white px-4 py-2 rounded-lg font-medium transition-all duration-200 flex items-center gap-2" id="bulkDeleteProxyBtn">
|
||||
<i class="fas fa-trash-alt"></i> 删除代理
|
||||
</button>
|
||||
<button type="button" class="bg-primary-600 hover:bg-primary-700 text-white px-4 py-2 rounded-lg font-medium transition-all duration-200 flex items-center gap-2" id="addProxyBtn">
|
||||
<i class="fas fa-plus"></i> 添加代理
|
||||
</button>
|
||||
</div>
|
||||
<small class="text-gray-500 mt-1 block">代理服务器列表,支持 http 和 socks5 格式,例如: http://user:pass@host:port 或 socks5://host:port。点击按钮可批量添加或删除。</small>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模型相关配置 -->
|
||||
@@ -295,6 +311,20 @@
|
||||
</div> -->
|
||||
<small class="text-gray-500 mt-1 block">为每个思考模型设置预算(整数,最大值 24576),此项与上方模型列表自动关联。</small>
|
||||
</div>
|
||||
<!-- 安全设置 -->
|
||||
<div class="mb-6">
|
||||
<label for="SAFETY_SETTINGS" class="block font-semibold mb-2 text-gray-700">安全设置 (Safety Settings)</label>
|
||||
<div class="bg-white rounded-lg border border-gray-200 p-4 mb-2 space-y-3" id="SAFETY_SETTINGS_container">
|
||||
<!-- 安全设置项将在这里动态添加 -->
|
||||
<div class="text-gray-500 text-sm italic">定义模型的安全过滤阈值。</div>
|
||||
</div>
|
||||
<div class="flex justify-end">
|
||||
<button type="button" class="bg-primary-600 hover:bg-primary-700 text-white px-4 py-2 rounded-lg font-medium transition-all duration-200 flex items-center gap-2" id="addSafetySettingBtn">
|
||||
<i class="fas fa-plus"></i> 添加安全设置
|
||||
</button>
|
||||
</div>
|
||||
<small class="text-gray-500 mt-1 block">配置模型的安全过滤级别,例如 HARM_CATEGORY_HARASSMENT: BLOCK_NONE。</small>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 图像生成相关配置 -->
|
||||
@@ -511,6 +541,43 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Proxy Add Modal -->
|
||||
<div id="proxyModal" class="modal">
|
||||
<div class="w-full max-w-lg mx-auto bg-white rounded-2xl shadow-2xl overflow-hidden animate-fade-in">
|
||||
<div class="p-6">
|
||||
<div class="flex justify-between items-center mb-4">
|
||||
<h2 class="text-xl font-bold text-gray-800">批量添加代理服务器</h2>
|
||||
<button id="closeProxyModalBtn" class="text-gray-400 hover:text-gray-600 text-xl">×</button>
|
||||
</div>
|
||||
<p class="text-gray-600 mb-4">每行粘贴一个或多个代理地址,将自动提取有效地址并去重。</p>
|
||||
<textarea id="proxyBulkInput" rows="10" placeholder="在此处粘贴代理地址 (例如 http://user:pass@host:port 或 socks5://host:port)..." class="w-full px-4 py-3 rounded-lg border border-gray-300 focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 font-mono text-sm"></textarea>
|
||||
<div class="flex justify-end gap-3 mt-6">
|
||||
<button type="button" id="confirmAddProxyBtn" class="bg-primary-600 hover:bg-primary-700 text-white px-6 py-2 rounded-lg font-medium transition">确认添加</button>
|
||||
<button type="button" id="cancelAddProxyBtn" class="bg-gray-200 hover:bg-gray-300 text-gray-700 px-6 py-2 rounded-lg font-medium transition">取消</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bulk Delete Proxy Modal -->
|
||||
<div id="bulkDeleteProxyModal" class="modal">
|
||||
<div class="w-full max-w-lg mx-auto bg-white rounded-2xl shadow-2xl overflow-hidden animate-fade-in">
|
||||
<div class="p-6">
|
||||
<div class="flex justify-between items-center mb-4">
|
||||
<h2 class="text-xl font-bold text-gray-800">批量删除代理服务器</h2>
|
||||
<button id="closeBulkDeleteProxyModalBtn" class="text-gray-400 hover:text-gray-600 text-xl">×</button>
|
||||
</div>
|
||||
<p class="text-gray-600 mb-4">每行粘贴一个或多个代理地址,将自动提取有效地址并从列表中删除。</p>
|
||||
<textarea id="bulkDeleteProxyInput" rows="10" placeholder="在此处粘贴要删除的代理地址..." class="w-full px-4 py-3 rounded-lg border border-gray-300 focus:border-danger-500 focus:ring focus:ring-danger-200 focus:ring-opacity-50 font-mono text-sm"></textarea>
|
||||
<div class="flex justify-end gap-3 mt-6">
|
||||
<button type="button" id="confirmBulkDeleteProxyBtn" class="bg-danger-600 hover:bg-danger-700 text-white px-6 py-2 rounded-lg font-medium transition">确认删除</button>
|
||||
<button type="button" id="cancelBulkDeleteProxyBtn" class="bg-gray-200 hover:bg-gray-300 text-gray-700 px-6 py-2 rounded-lg font-medium transition">取消</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Reset Confirmation Modal -->
|
||||
<div id="resetConfirmModal" class="modal">
|
||||
<div class="w-full max-w-md mx-auto bg-white rounded-2xl shadow-2xl overflow-hidden animate-fade-in">
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
fastapi
|
||||
httpx
|
||||
httpx[socks]
|
||||
openai
|
||||
pydantic
|
||||
pydantic_settings
|
||||
@@ -16,6 +16,5 @@ sqlalchemy
|
||||
aiomysql
|
||||
databases
|
||||
python-dotenv
|
||||
apscheduler # 添加定时任务库
|
||||
|
||||
apscheduler
|
||||
packaging
|
||||
|
||||
Reference in New Issue
Block a user