Compare commits

...

11 Commits

Author SHA1 Message Date
snaily
929592bbc4 chore: bump version to 2.1.2 2025-05-02 22:49:50 +08:00
snaily
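2225a40bbe feat: add Gemini safety settings support
- Add a `SAFETY_SETTINGS` config option that lets users set the safety filtering level of Gemini models via an environment variable or the database.
- Update the backend services (`config.py`, `constants.py`, `gemini_routes.py`, `openai_routes.py`, `openai_chat_service.py`, `api_client.py`, `model_service.py`) to support and pass through the `safety_settings` parameter.
- Add a UI for managing safety settings to the config editor frontend (`config_editor.js`, `config_editor.html`).
- Make the model-fetching logic (`model_service.py`, `api_client.py`) asynchronous.
- Change the Service Worker (`service-worker.js`) caching strategy to "cache then network".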

Bump version to 2.1.2
2025-05-02 22:49:36 +08:00
snaily
3480fa3b0f Merge branch 'pr/tbphp/74' 2025-05-02 18:17:50 +08:00
tbphp
d7113f5fc4 fix: fix the impact of safety settings on output speed 2025-05-02 17:07:50 +08:00
snaily
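2072f54ca1 refactor: restructure error handling and streamline routes and services
Main changes:
- Add `app/handler/error_handler.py`, introducing the `handle_route_errors` async context manager for unified error handling and logging in routes.
- Apply `handle_route_errors` in `openai_routes` and `openai_compatiable_routes`, removing redundant try-except blocks and simplifying the route logic.
- Move `OpenAICompatiableService` to the `app/service/openai_compatiable/` directory.
- Move `StatsService` to the `app/service/stats/` directory and update the related import paths.
- Fix the `inlineData` field name used in `response_handler` when handling Gemini API responses (previously `inline_data`).
- Fix `openai_routes` and `openai_compatiable_routes` not using the paid API key correctly for image-generation chat (e.g. imagen3-chat).
- Change `httpx` to `httpx[socks]` in `requirements.txt` to add SOCKS proxy support.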
2025-05-02 01:20:05 +08:00
snaily
7c9b721164 chore: update README.md, adding the new OpenAI-compatible endpoints to the API endpoints section. 2025-04-30 20:49:14 +08:00
snaily
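83ce50975a feat: implement OpenAI-compatible API endpoints and bulk proxy deletion
New API endpoints compatible with the OpenAI specification:
- `/openai/v1/models`
- `/openai/v1/chat/completions` (supports streaming, retries, and key switching)
- `/openai/v1/embeddings`
- `/openai/v1/images/generations`

Includes:
- New routes in `app/router/openai_compatiable_routes.py`.
- `OpenAICompatiableService` for handling request logic, logging, and error management.
- Update `OpenaiApiClient` to support the new methods and proxy use.
- Modify `app/domain/openai_models.py` for compatibility.
- Add a dedicated logger (`openai_compatible`) for the new API.
- Exempt the new routes (`/openai`, `/api/version/check`) from the authentication middleware.

Config editor UI enhancements:
- Add bulk proxy deletion in `app/static/js/config_editor.js` and `app/templates/config_editor.html`.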
2025-04-30 20:39:47 +08:00
snaily
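7da9110704 feat: add proxy support (HTTP/SOCKS5)
Adds the ability for the application to access the Gemini API through a proxy server.

Main changes:

*   **Configuration**:
    *   Added a `PROXIES` option to `.env.example` and `app/config/config.py`, letting users specify a list of one or more HTTP or SOCKS5 proxy servers.
    *   Updated `README.md` with instructions for proxy configuration.
*   **Backend**:
    *   Modified `GeminiApiClient` in `app/service/client/api_client.py` to pick a proxy at random from the configured `PROXIES` list when making a request.
    *   Added `get_api_client_logger` to `app/log/logger.py` for logging API client activity, including proxy use.
*   **Frontend**:
    *   Added a proxy list display area and an "Add proxy" button to the config editor page, `app/templates/config_editor.html`.
    *   Implemented a modal UI for adding proxies in bulk.
    *   Added JavaScript logic to `app/static/js/config_editor.js` for displaying the proxy list, opening and closing the modal, and bulk-adding proxies (including extraction, deduplication, and UI updates).
    *   Ensured `PROXIES` defaults to an empty list when the configuration is initialized.

This lets users run the application in environments where external network access has to go through a proxy.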
2025-04-30 10:57:17 +08:00
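The random proxy selection described in this commit can be sketched roughly as follows. This is a minimal illustration, not the project's actual `GeminiApiClient` code: the `pick_proxy` helper and its optional `rng` parameter are hypothetical names, and the sketch assumes `PROXIES` is a plain list of proxy URL strings.

```python
import random
from typing import List, Optional


def pick_proxy(proxies: List[str], rng: Optional[random.Random] = None) -> Optional[str]:
    """Return one proxy URL chosen at random, or None when no proxies are configured."""
    if not proxies:
        # An empty PROXIES list means "connect directly".
        return None
    chooser = rng or random
    return chooser.choice(proxies)
```

Passing a seeded `random.Random` instance as `rng` makes the choice reproducible in tests, while production code can omit it.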
snaily
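e9d19de7c6 refactor: move media constants and refactor related handling logic
Move the audio/video configuration (supported formats, size limits, MIME types) from `config.py` to `core/constants.py` so that constants are managed in one place.

Update `message_converter.py`:
- Import the media constants from `core.constants`.
- Add and use a dedicated logger for `message_converter`.
- Clean up imports and code formatting.

Update `openai_chat_service.py`:
- Adjust the `_has_media_parts` function to detect `inline_data` correctly.
- Clean up imports and code formatting.

Add the `get_message_converter_logger` function to `log/logger.py`.

Made the corresponding removals and minor code cleanups in `config.py` and `response_handler.py`.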
2025-04-29 17:54:48 +08:00
Your Name (aider)
e822831178 fix: remove duplicate convert method in message converter 2025-04-26 03:35:16 +00:00
Your Name (aider)
775930edce feat: add support for audio and video input via base64
This commit adds configuration and conversion logic to handle audio and video inputs in base64 format, similar to existing image support. It includes:

1. Added supported formats and size limits in config
2. Implemented media validation and conversion in message converter
3. Updated payload building to handle media parts
4. Improved error handling and logging for media processing
2025-04-26 03:07:54 +00:00
26 changed files with 1608 additions and 419 deletions

@@ -23,6 +23,9 @@ CHECK_INTERVAL_HOURS=1
TIMEZONE=Asia/Shanghai
# Request timeout (seconds)
TIME_OUT=300
# Proxy server configuration (supports http and socks5)
# Example: PROXIES=["http://user:pass@host:port", "socks5://host:port"]
PROXIES=[]
#########################image_generate settings###########################
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
@@ -44,3 +47,7 @@ STREAM_CHUNK_SIZE=5
# Log level (debug, info, warning, error, critical); defaults to info
LOG_LEVEL=info
##########################################################################
# Safety settings (JSON string format)
# Note: the example values here may need adjusting to what the model actually supports
SAFETY_SETTINGS='[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}]'

@@ -67,6 +67,7 @@ app/
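>Image address: docker pull ghcr.io/snailyp/gemini-balance:latest
* **Automatic model list maintenance**: Supports fetching both OpenAI and Gemini model lists and is fully compatible with newapi's automatic model list retrieval, so no manual entry is needed.
* **Remove unused models**: The default model list is long and many entries go unused; filter them out with `FILTERED_MODELS`.
* **Proxy support**: Supports configuring HTTP/SOCKS5 proxy servers (`PROXIES`) for accessing the Gemini API, which helps in restricted network environments. Proxies can be added in bulk.
## 🚀 Quick Start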
@@ -166,6 +167,7 @@ app/
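| `CHECK_INTERVAL_HOURS` | Optional. Interval for checking whether disabled keys have recovered (hours) | `1` |
| `TIMEZONE` | Optional. Time zone used by the application | `Asia/Shanghai` |
| `TIME_OUT` | Optional. Request timeout (seconds) | `300` |
| `PROXIES` | Optional. List of proxy servers (e.g. `http://user:pass@host:port`, `socks5://host:port`) | `[]` |
| `LOG_LEVEL` | Optional. Log level, e.g. DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
| **Image generation** | | |
| `PAID_KEY` | Optional. Paid API key, used for advanced features such as image generation | `your-paid-api-key` |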
@@ -193,12 +195,16 @@ app/
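* `POST /models/{model_name}:generateContent`: Generate content with the specified Gemini model.
* `POST /models/{model_name}:streamGenerateContent`: Stream generated content with the specified Gemini model.
### OpenAI API (`(/hf)/v1`)
### OpenAI API
* `GET /v1/models`: List available OpenAI models.
* `POST /v1/chat/completions`: Chat completion via the OpenAI API.
* `POST /v1/images/generations`: Generate images via the OpenAI API
* `POST /v1/embeddings`: Create text embeddings via the OpenAI API
* `GET (/hf)/v1/models`: List available models (uses the Gemini format under the hood)
* `POST (/hf)/v1/chat/completions`: Chat completion (uses the Gemini format under the hood; supports streaming)
* `POST (/hf)/v1/embeddings`: Create text embeddings (uses the Gemini format under the hood)
* `POST (/hf)/v1/images/generations`: Generate images (uses the Gemini format under the hood)
* `GET /openai/v1/models`: List available models (uses the OpenAI format under the hood).
* `POST /openai/v1/chat/completions`: Chat completion (uses the OpenAI format under the hood; supports streaming, can prevent truncation, and is also fast).
* `POST /openai/v1/embeddings`: Create text embeddings (uses the OpenAI format under the hood).
* `POST /openai/v1/images/generations`: Generate images (uses the OpenAI format under the hood).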
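A request to the chat endpoint listed above might be assembled as in the sketch below. It only builds the request object and sends nothing. The `build_chat_request` helper is a hypothetical name, and the bearer-token `Authorization` header is an assumption based on how OpenAI clients usually authenticate; the README excerpt does not state the expected header.

```python
import json
import urllib.request
from typing import Dict, List


def build_chat_request(
    base_url: str, api_key: str, model: str, messages: List[Dict[str, str]]
) -> urllib.request.Request:
    """Build (but do not send) a POST request for the OpenAI-compatible chat endpoint."""
    body = json.dumps({"model": model, "messages": messages, "stream": False})
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/openai/v1/chat/completions",
        data=body.encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            # Assumption: the service accepts a bearer token, as OpenAI clients usually send.
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )
```

The returned object can be passed to `urllib.request.urlopen` once the base URL and key of a real deployment are known.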
## 🤝 贡献

@@ -1 +1 @@
2.1.0
2.1.2

@@ -9,7 +9,7 @@ from pydantic import ValidationError
from pydantic_settings import BaseSettings
from sqlalchemy import insert, update, select
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT, MAX_RETRIES
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_SAFETY_SETTINGS, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT, MAX_RETRIES
from app.log.logger import Logger
@@ -30,7 +30,8 @@ class Settings(BaseSettings):
TEST_MODEL: str = DEFAULT_MODEL
TIME_OUT: int = DEFAULT_TIMEOUT
MAX_RETRIES: int = MAX_RETRIES
PROXIES: List[str] = [] # New: list of proxy servers
# Model-related settings
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
@@ -68,6 +69,7 @@ class Settings(BaseSettings):
# Logging settings
LOG_LEVEL: str = "INFO" # Default log level
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS # New: safety settings
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -120,6 +122,32 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
# Log other errors (ValueError, TypeError) or JSON errors without single quotes
logger.error(f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict.")
return parsed_dict # Return the parsed dict or an empty one if all attempts fail
# Handle List[Dict[str, str]]
elif target_type == List[Dict[str, str]]:
try:
parsed = json.loads(db_value)
if isinstance(parsed, list):
# Verify that every list element is a dict whose keys and values are all strings
valid = all(
isinstance(item, dict) and
all(isinstance(k, str) for k in item.keys()) and
all(isinstance(v, str) for v in item.values())
for item in parsed
)
if valid:
return parsed
else:
logger.warning(f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}")
return [] # Or return the default value? An empty list is returned here
else:
logger.warning(f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}")
return []
except json.JSONDecodeError:
logger.error(f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list.")
return []
except Exception as e:
logger.error(f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list.")
return []
# Handle bool
elif target_type == bool:
return db_value.lower() in ('true', '1', 'yes', 'on')
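The `List[Dict[str, str]]` branch in this hunk can be exercised in isolation with a sketch like the one below. It is a simplified stand-in, not the project's `_parse_db_value`: the `parse_str_dict_list` name and the `default` parameter are hypothetical, added to illustrate the "return the default value?" question raised in the code comment.

```python
import json
from typing import Dict, List, Optional


def parse_str_dict_list(
    raw: str, default: Optional[List[Dict[str, str]]] = None
) -> List[Dict[str, str]]:
    """Parse a JSON string into a list of str-to-str dicts, falling back on bad input."""
    fallback = list(default) if default is not None else []
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return fallback
    if not isinstance(parsed, list):
        return fallback
    for item in parsed:
        # JSON object keys are always strings, so only the values need checking.
        if not isinstance(item, dict):
            return fallback
        if not all(isinstance(value, str) for value in item.values()):
            return fallback
    return parsed
```

With a caller-supplied default, a malformed `SAFETY_SETTINGS` value in the database would degrade to known settings instead of an empty list; whether that is preferable is a design choice the diff leaves open.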

@@ -40,3 +40,40 @@ DEFAULT_STREAM_CHUNK_SIZE = 5
# 正则表达式模式
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
# Audio/Video Settings
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
# Optional: Define MIME type mappings if needed, or handle directly in converter
AUDIO_FORMAT_TO_MIMETYPE = {
"wav": "audio/wav",
"mp3": "audio/mpeg",
"flac": "audio/flac",
"ogg": "audio/ogg",
}
VIDEO_FORMAT_TO_MIMETYPE = {
"mp4": "video/mp4",
"mov": "video/quicktime",
"avi": "video/x-msvideo",
"webm": "video/webm",
}
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
DEFAULT_SAFETY_SETTINGS = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
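As a small illustration of how the format tables in this file might be consulted, the sketch below normalizes a user-supplied audio format and looks up its MIME type. The two constants are copied from the hunk above; the `audio_mime_type` helper is a hypothetical name and is not part of the diff.

```python
from typing import Dict, List, Optional

# Copied from the constants added in this diff.
SUPPORTED_AUDIO_FORMATS: List[str] = ["wav", "mp3", "flac", "ogg"]
AUDIO_FORMAT_TO_MIMETYPE: Dict[str, str] = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "flac": "audio/flac",
    "ogg": "audio/ogg",
}


def audio_mime_type(fmt: str) -> Optional[str]:
    """Return the MIME type for a supported audio format, or None if unsupported."""
    normalized = fmt.strip().lower()
    if normalized not in SUPPORTED_AUDIO_FORMATS:
        return None
    return AUDIO_FORMAT_TO_MIMETYPE.get(normalized)
```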

@@ -1,5 +1,5 @@
from pydantic import BaseModel
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
@@ -9,11 +9,14 @@ class ChatRequest(BaseModel):
model: str = DEFAULT_MODEL
temperature: Optional[float] = DEFAULT_TEMPERATURE
stream: Optional[bool] = False
tools: Optional[List[dict]] = []
max_tokens: Optional[int] = None
top_p: Optional[float] = DEFAULT_TOP_P
top_k: Optional[int] = DEFAULT_TOP_K
stop: Optional[List[str]] = []
stop: Optional[Union[List[str],str]] = None
reasoning_effort: Optional[str] = None
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
tool_choice: Optional[str] = None
response_format: Optional[dict] = None
class EmbeddingRequest(BaseModel):
@@ -23,10 +26,10 @@ class EmbeddingRequest(BaseModel):
class ImageGenerationRequest(BaseModel):
model: str = "DALL-E-3"
model: str = "imagen-3.0-generate-002"
prompt: str = ""
n: int = 1
size: Optional[str] = "1024x1024"
quality: Optional[str] = ""
style: Optional[str] = ""
response_format: Optional[str] = "url"
quality: Optional[str] = None
style: Optional[str] = None
response_format: Optional[str] = "b64_json"

@@ -0,0 +1,32 @@
from contextlib import asynccontextmanager
from fastapi import HTTPException
import logging
@asynccontextmanager
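async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None):
"""
An async context manager for unified handling of common errors and logging in FastAPI routes.
Args:
logger: The Logger instance used for logging.
operation_name: Name of the operation, used in log messages and error details.
success_message: Custom message logged when the operation succeeds (optional).
failure_message: Custom message logged when the operation fails (optional).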
"""
default_success_msg = f"{operation_name} request successful"
default_failure_msg = f"{operation_name} request failed"
logger.info("-" * 50 + operation_name + "-" * 50)
try:
yield
logger.info(success_message or default_success_msg)
except HTTPException as http_exc:
# If it is already an HTTPException, re-raise it directly to keep the original status code and detail
logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})")
raise http_exc
except Exception as e:
# For all other exceptions, log the error and raise a standard 500 error
logger.error(f"{failure_message or default_failure_msg}: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Internal server error during {operation_name}"
) from e
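The pattern in this file can be tried out without FastAPI using the dependency-free sketch below. Everything in it is a stand-in: `HTTPError` substitutes for `fastapi.HTTPException`, `guarded` is a trimmed-down version of `handle_route_errors`, and `list_models` is a hypothetical route body. It shows how the context manager might replace a per-route try-except block.

```python
import asyncio
import logging
from contextlib import asynccontextmanager


class HTTPError(Exception):
    """Stand-in for fastapi.HTTPException (assumption, keeps the sketch dependency-free)."""

    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


@asynccontextmanager
async def guarded(logger: logging.Logger, operation_name: str):
    """Log the outcome of the wrapped block and turn unexpected errors into a 500."""
    try:
        yield
        logger.info("%s request successful", operation_name)
    except HTTPError:
        raise  # already an HTTP error: keep its status code and detail
    except Exception as exc:
        logger.error("%s request failed: %s", operation_name, exc)
        raise HTTPError(500, f"Internal server error during {operation_name}") from exc


async def list_models(fail: bool) -> list:
    """Hypothetical route body: the context manager replaces a try/except block."""
    async with guarded(logging.getLogger("demo"), "list models"):
        if fail:
            raise ValueError("upstream unavailable")
        return ["gemini-2.0-flash-exp"]
```

Calling `asyncio.run(list_models(fail=True))` raises the stand-in `HTTPError` with status 500, while the original `ValueError` stays attached as `__cause__` for debugging.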

@@ -1,61 +1,70 @@
from abc import ABC, abstractmethod
import base64
import json
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import requests
import base64
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
import requests
from app.core.constants import (
AUDIO_FORMAT_TO_MIMETYPE,
DATA_URL_PATTERN,
IMAGE_URL_PATTERN,
MAX_AUDIO_SIZE_BYTES,
MAX_VIDEO_SIZE_BYTES,
SUPPORTED_AUDIO_FORMATS,
SUPPORTED_ROLES,
SUPPORTED_VIDEO_FORMATS,
VIDEO_FORMAT_TO_MIMETYPE,
)
from app.log.logger import get_message_converter_logger
logger = get_message_converter_logger()
class MessageConverter(ABC):
"""Base class for message converters"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
def convert(
self, messages: List[Dict[str, Any]]
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
def _get_mime_type_and_data(base64_string):
"""
Extract the MIME type and data from a base64 string.
Args:
base64_string (str): A base64 string that may include MIME type information
Returns:
tuple: (mime_type, encoded_data)
"""
# Check whether the string starts with the "data:" prefix
if base64_string.startswith('data:'):
if base64_string.startswith("data:"):
# Extract the MIME type and data
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
mime_type = (
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
)
encoded_data = match.group(2)
return mime_type, encoded_data
# If it is not in the expected format, assume it is just the data portion
return None, base64_string
def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
mime_type, encoded_data = _get_mime_type_and_data(image_url)
return {
"inline_data": {
"mime_type": mime_type,
"data": encoded_data
}
}
return {"inline_data": {"mime_type": mime_type, "data": encoded_data}}
else:
encoded_data = _convert_image_to_base64(image_url)
return {
"inline_data": {
"mime_type": "image/png",
"data": encoded_data
}
}
return {"inline_data": {"mime_type": "image/png", "data": encoded_data}}
def _convert_image_to_base64(url: str) -> str:
@@ -69,7 +78,7 @@ def _convert_image_to_base64(url: str) -> str:
response = requests.get(url)
if response.status_code == 200:
# Convert the image content to base64
img_data = base64.b64encode(response.content).decode('utf-8')
img_data = base64.b64encode(response.content).decode("utf-8")
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
@@ -93,12 +102,9 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
# Convert the image at the URL to base64
try:
base64_data = _convert_image_to_base64(img_url)
parts.append({
"inlineData": {
"mimeType": "image/png",
"data": base64_data
}
})
parts.append(
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
)
except Exception:
# If conversion fails, fall back to text mode
parts.append({"text": text})
@@ -111,42 +117,215 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
class OpenAIMessageConverter(MessageConverter):
"""Converter for the OpenAI message format"""
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
def _validate_media_data(
self, format: str, data: str, supported_formats: List[str], max_size: int
) -> tuple[Optional[str], Optional[str]]:
"""Validates format and size of Base64 media data."""
if format.lower() not in supported_formats:
logger.error(
f"Unsupported media format: {format}. Supported: {supported_formats}"
)
raise ValueError(f"Unsupported media format: {format}")
try:
# Decode Base64 to check size
# Be careful with memory usage for very large files
# Consider streaming decoding or checking length heuristic first if memory is a concern
decoded_data = base64.b64decode(
data, validate=True
) # Use validate=True for stricter check
if len(decoded_data) > max_size:
logger.error(
f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
)
raise ValueError(
f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
)
# No need to return decoded_data, just the original base64 if valid
return data
except base64.binascii.Error as e:
logger.error(f"Invalid Base64 data provided: {e}")
raise ValueError("Invalid Base64 data")
except Exception as e:
logger.error(f"Error validating media data: {e}")
raise
def convert(
self, messages: List[Dict[str, Any]]
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
system_instruction_parts = []
for idx, msg in enumerate(messages):
role = msg.get("role", "")
parts = []
# Special-case the last assistant message: split it on \n\n
if "content" in msg and isinstance(msg["content"], str) and msg["content"] and role == "assistant" and idx == len(messages) - 2:
# Split the message on \n\n
content_parts = msg["content"].split("\n\n")
for part in content_parts:
if not part.strip(): # Skip empty content
if "content" in msg and isinstance(msg["content"], list):
for content_item in msg["content"]:
if not isinstance(content_item, dict):
# Skip non-dict items if any unexpected format appears
logger.warning(
f"Skipping unexpected content item format: {type(content_item)}"
)
continue
# Handle text that may contain images
parts.extend(_process_text_with_image(part))
elif "content" in msg and isinstance(msg["content"], str) and msg["content"]:
# The Gemini API returns a 400 error when a request contains a content field with empty content, so check for emptiness and drop it
content_type = content_item.get("type")
if content_type == "text" and content_item.get("text"):
parts.append({"text": content_item["text"]})
elif content_type == "image_url" and content_item.get(
"image_url", {}
).get("url"):
try:
parts.append(
_convert_image(content_item["image_url"]["url"])
)
except Exception as e:
logger.error(
f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
)
# Decide how to handle: skip part, add error text, etc.
parts.append(
{
"text": f"[Error processing image: {content_item['image_url']['url']}]"
}
)
# --- Add handling for input_audio ---
elif content_type == "input_audio" and content_item.get(
"input_audio"
):
audio_info = content_item["input_audio"]
audio_data = audio_info.get("data")
audio_format = audio_info.get("format", "").lower()
if not audio_data or not audio_format:
logger.warning(
"Skipping audio part due to missing data or format."
)
continue
try:
# Validate size and format
validated_data = self._validate_media_data(
audio_format,
audio_data,
SUPPORTED_AUDIO_FORMATS,
MAX_AUDIO_SIZE_BYTES,
)
# Get MIME type
mime_type = AUDIO_FORMAT_TO_MIMETYPE.get(audio_format)
if not mime_type:
# Should not happen if format validation passed, but double-check
logger.error(
f"Could not find MIME type for supported format: {audio_format}"
)
raise ValueError(
f"Internal error: MIME type mapping missing for {audio_format}"
)
parts.append(
{
"inline_data": {
"mimeType": mime_type,
"data": validated_data, # Use the validated Base64 data
}
}
)
logger.debug(
f"Successfully added audio part (format: {audio_format})"
)
except ValueError as e:
logger.error(
f"Skipping audio part due to validation error: {e}"
)
parts.append({"text": f"[Error processing audio: {e}]"})
except Exception:
logger.exception("Unexpected error processing audio part.")
parts.append(
{"text": "[Unexpected error processing audio]"}
)
elif content_type == "input_video" and content_item.get(
"input_video"
):
video_info = content_item["input_video"]
video_data = video_info.get("data")
video_format = video_info.get("format", "").lower()
if not video_data or not video_format:
logger.warning(
"Skipping video part due to missing data or format."
)
continue
try:
validated_data = self._validate_media_data(
video_format,
video_data,
SUPPORTED_VIDEO_FORMATS,
MAX_VIDEO_SIZE_BYTES,
)
mime_type = VIDEO_FORMAT_TO_MIMETYPE.get(video_format)
if not mime_type:
raise ValueError(
f"Internal error: MIME type mapping missing for {video_format}"
)
parts.append(
{
"inline_data": {
"mimeType": mime_type,
"data": validated_data,
}
}
)
logger.debug(
f"Successfully added video part (format: {video_format})"
)
except ValueError as e:
logger.error(
f"Skipping video part due to validation error: {e}"
)
parts.append({"text": f"[Error processing video: {e}]"})
except Exception:
logger.exception("Unexpected error processing video part.")
parts.append(
{"text": "[Unexpected error processing video]"}
)
else:
# Log unrecognized but present types
if content_type:
logger.warning(
f"Unsupported content type or missing data in structured content: {content_type}"
)
elif (
"content" in msg and isinstance(msg["content"], str) and msg["content"]
):
parts.extend(_process_text_with_image(msg["content"]))
elif "content" in msg and isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str) and content:
parts.append({"text": content})
elif isinstance(content, dict):
if content["type"] == "text" and content["text"]:
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
# Keep existing tool call processing
for tool_call in msg["tool_calls"]:
function_call = tool_call.get("function",{})
function_call["args"] = json.loads(function_call.get("arguments","{}"))
del function_call["arguments"]
function_call = tool_call.get("function", {})
# Sanitize arguments loading
arguments_str = function_call.get("arguments", "{}")
try:
function_call["args"] = json.loads(arguments_str)
except json.JSONDecodeError:
logger.warning(
f"Failed to decode tool call arguments: {arguments_str}"
)
function_call["args"] = {}
if "arguments" in function_call:
if "arguments" in function_call:
del function_call["arguments"]
parts.append({"functionCall": function_call})
if role not in SUPPORTED_ROLES:
if role == "tool":
role = "user"
@@ -158,7 +337,14 @@ class OpenAIMessageConverter(MessageConverter):
role = "model"
if parts:
if role == "system":
system_instruction_parts.extend(parts)
text_only_parts = [p for p in parts if "text" in p]
if len(text_only_parts) != len(parts):
logger.warning(
"Non-text parts found in system message; discarding them."
)
if text_only_parts:
system_instruction_parts.extend(text_only_parts)
else:
converted_messages.append({"role": role, "parts": parts})
@@ -170,4 +356,4 @@ class OpenAIMessageConverter(MessageConverter):
"parts": system_instruction_parts,
}
)
return converted_messages, system_instruction
return converted_messages, system_instruction
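The comment in `_validate_media_data` suggests checking a length heuristic before decoding when memory is a concern. One way that might look is sketched below. The helper names are hypothetical and not part of the diff, and the size estimate assumes standard, padded Base64 with no embedded whitespace.

```python
import base64
import binascii


def estimated_decoded_size(b64_data: str) -> int:
    """Return the decoded byte count of padded, whitespace-free Base64 text without decoding it."""
    padding = len(b64_data) - len(b64_data.rstrip("="))
    return (len(b64_data) // 4) * 3 - padding


def validate_base64_media(b64_data: str, max_size: int) -> str:
    """Reject oversized payloads cheaply, then run the strict decode check."""
    if estimated_decoded_size(b64_data) > max_size:
        raise ValueError(f"Media data size exceeds limit of {max_size} bytes")
    try:
        # The full decode only runs for payloads that passed the cheap size check.
        base64.b64decode(b64_data, validate=True)
    except binascii.Error as exc:
        raise ValueError("Invalid Base64 data") from exc
    return b64_data
```

Ordering the checks this way means an oversized upload is rejected from its string length alone, so the large decoded buffer is never allocated.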

@@ -1,12 +1,12 @@
import base64
import json
import random
import string
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
@@ -15,7 +15,9 @@ class ResponseHandler(ABC):
"""Base class for response handlers"""
@abstractmethod
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
def handle_response(
self, response: Dict[str, Any], model: str, stream: bool = False
) -> Dict[str, Any]:
pass
@@ -26,14 +28,20 @@ class GeminiResponseHandler(ResponseHandler):
self.thinking_first = True
self.thinking_status = False
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
def handle_response(
self, response: Dict[str, Any], model: str, stream: bool = False
) -> Dict[str, Any]:
if stream:
return _handle_gemini_stream_response(response, model, stream)
return _handle_gemini_normal_response(response, model, stream)
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
def _handle_openai_stream_response(
response: Dict[str, Any], model: str, finish_reason: str
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=True, gemini_format=False
)
if not text and not tool_calls:
delta = {}
else:
@@ -50,8 +58,12 @@ def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_
}
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
def _handle_openai_normal_response(
response: Dict[str, Any], model: str, finish_reason: str
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=False, gemini_format=False
)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
@@ -60,7 +72,11 @@ def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
"message": {
"role": "assistant",
"content": text,
"tool_calls": tool_calls,
},
"finish_reason": finish_reason,
}
],
@@ -77,59 +93,67 @@ class OpenAIResponseHandler(ResponseHandler):
self.thinking_status = False
def handle_response(
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None,
) -> Optional[Dict[str, Any]]:
if stream:
return _handle_openai_stream_response(response, model, finish_reason)
return _handle_openai_normal_response(response, model, finish_reason)
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
def handle_image_chat_response(
self, image_str: str, model: str, stream=False, finish_reason="stop"
):
if stream:
return _handle_openai_stream_image_response(image_str,model,finish_reason)
return _handle_openai_normal_image_response(image_str,model,finish_reason)
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
return _handle_openai_stream_image_response(image_str, model, finish_reason)
return _handle_openai_normal_image_response(image_str, model, finish_reason)
def _handle_openai_stream_image_response(
image_str: str, model: str, finish_reason: str
) -> Dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": image_str} if image_str else {},
"finish_reason": finish_reason
}]
"choices": [
{
"index": 0,
"delta": {"content": image_str} if image_str else {},
"finish_reason": finish_reason,
}
],
}
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
def _handle_openai_normal_image_response(
image_str: str, model: str, finish_reason: str
) -> Dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": image_str
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": image_str},
"finish_reason": finish_reason,
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
def _extract_result(
response: Dict[str, Any],
model: str,
stream: bool = False,
gemini_format: bool = False,
) -> tuple[str, List[Dict[str, Any]]]:
text, tool_calls = "", []
if stream:
if response.get("candidates"):
@@ -145,13 +169,9 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
elif "codeExecution" in parts[0]:
text = _format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = _format_execution_result(
parts[0]["executableCodeResult"]
)
text = _format_execution_result(parts[0]["executableCodeResult"])
elif "codeExecutionResult" in parts[0]:
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
text = _format_execution_result(parts[0]["codeExecutionResult"])
elif "inlineData" in parts[0]:
text = _extract_image_data(parts[0])
else:
@@ -165,10 +185,10 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
@@ -186,34 +206,47 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
elif "inlineData" in part:
text += _extract_image_data(part)
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
tool_calls = _extract_tool_calls(
candidate["content"]["parts"], gemini_format
)
else:
text = "暂无返回"
return text, tool_calls
def _extract_image_data(part: dict) -> str:
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
#Convert base64_data to a bytes array
# Convert base64_data to a bytes array
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename)
upload_response = image_uploader.upload(bytes_data, filename)
if upload_response.success:
text = f"\n\n![image]({upload_response.data.url})\n\n"
else:
text = ""
return text
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
def _extract_tool_calls(
parts: List[Dict[str, Any]], gemini_format: bool
) -> List[Dict[str, Any]]:
"""Extract tool call information"""
if not parts or not isinstance(parts, list):
return []
@@ -249,8 +282,12 @@ def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> Lis
return tool_calls
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
def _handle_gemini_stream_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=stream, gemini_format=True
)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
@@ -259,8 +296,12 @@ def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream:
return response
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
def _handle_gemini_normal_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=stream, gemini_format=True
)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
@@ -278,10 +319,10 @@ def _format_code_block(code_data: dict) -> str:
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
if (
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
):
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
text += "\n\n---\n\n"

View File

@@ -206,3 +206,15 @@ def get_update_logger():
def get_scheduler_routes():
return Logger.setup_logger("scheduler_routes")
def get_message_converter_logger():
return Logger.setup_logger("message_converter")
def get_api_client_logger():
return Logger.setup_logger("api_client")
def get_openai_compatible_logger():
return Logger.setup_logger("openai_compatible")

View File

@@ -30,6 +30,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
and not request.url.path.startswith(f"/{API_VERSION}")
and not request.url.path.startswith("/health")
and not request.url.path.startswith("/hf")
and not request.url.path.startswith("/openai")
and not request.url.path.startswith("/api/version/check")
):
auth_token = request.cookies.get("auth_token")
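The hunk above exempts two more path prefixes from cookie authentication. A minimal sketch of that prefix check, pulled out into a standalone function, might look as follows. The function name `requires_cookie_auth`, the `EXEMPT_PREFIXES` tuple, and the value given to `API_VERSION` are assumptions for illustration; the real middleware checks further paths that are not visible in this hunk.

```python
from typing import Iterable

# Assumed value for illustration; the diff only shows that API_VERSION
# is imported from app.core.constants.
API_VERSION = "v1beta"

# Prefixes visible in the hunk above (hypothetical grouping into a tuple).
EXEMPT_PREFIXES = (
    f"/{API_VERSION}",
    "/health",
    "/hf",
    "/openai",
    "/api/version/check",
)


def requires_cookie_auth(path: str, exempt_prefixes: Iterable[str] = EXEMPT_PREFIXES) -> bool:
    """Return True when the path is not covered by any exempt prefix."""
    # str.startswith accepts a tuple and matches if any prefix matches.
    return not path.startswith(tuple(exempt_prefixes))
```

Because this is plain prefix matching, a path such as `/openaix` would also be exempt; matching on `/openai/` instead would be stricter, if that matters for a deployment.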

View File

@@ -1,15 +1,16 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
import asyncio
from app.config.config import settings
from app.log.logger import get_gemini_logger
from app.core.security import SecurityService
import asyncio # import asyncio
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest # add imports
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.core.constants import API_VERSION
# Route setup
@@ -43,62 +44,57 @@ async def list_models(
_=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager)
):
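"""Get the list of available Gemini models."""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
"""Get the list of available Gemini models and add derived models (search, image, non-thinking) according to the configuration."""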
operation_name = "list_gemini_models"
logger.info("-" * 50 + operation_name + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
# Add search models
if settings.SEARCH_MODELS:
for name in settings.SEARCH_MODELS:
model = model_mapping.get(name)
try:
api_key = await key_manager.get_first_valid_key()
if not api_key:
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
logger.info(f"Using API key: {api_key}")
models_data =await model_service.get_gemini_models(api_key)
if not models_data or "models" not in models_data:
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
models_json = deepcopy(models_data) # Work on a copy to avoid mutating the original cache
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
def add_derived_model(base_name, suffix, display_suffix):
model = model_mapping.get(base_name)
if not model:
continue
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
return
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["name"] = f"models/{base_name}{suffix}"
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# Add image generation models
if settings.IMAGE_MODELS:
for name in settings.IMAGE_MODELS:
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# Add non-thinking variants of the thinking models
if settings.THINKING_MODELS:
for name in settings.THINKING_MODELS:
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-non-thinking"
display_name = f'{item.get("displayName")} Non Thinking'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json
# Add derived models
if settings.SEARCH_MODELS:
for name in settings.SEARCH_MODELS:
add_derived_model(name, "-search", " For Search")
if settings.IMAGE_MODELS:
for name in settings.IMAGE_MODELS:
add_derived_model(name, "-image", " For Image")
if settings.THINKING_MODELS:
for name in settings.THINKING_MODELS:
add_derived_model(name, "-non-thinking", " Non Thinking")
logger.info("Gemini models list request successful")
return models_json
except HTTPException as http_exc:
# Re-raise known HTTP exceptions
raise http_exc
except Exception as e:
logger.error(f"Error getting Gemini models list: {str(e)}")
raise HTTPException(
status_code=500, detail="Internal server error while fetching Gemini models list"
) from e
@router.post("/models/{model_name}:generateContent")
@@ -112,25 +108,22 @@ async def generate_content(
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
):
"""Non-streaming content generation."""
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
"""Handle Gemini non-streaming content generation requests."""
operation_name = "gemini_generate_content"
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/models/{model_name}:streamGenerateContent")
@@ -144,25 +137,24 @@ async def stream_generate_content(
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
):
"""Streaming content generation."""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
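"""Handle Gemini streaming content generation requests."""
operation_name = "gemini_stream_generate_content"
# Success/failure logging for a streaming request is more involved inside the stream itself; the context manager here only handles startup errors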
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
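# Note: errors in the stream itself must be handled in the service layer or during stream iteration; this only returns the streaming response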
return StreamingResponse(response_stream, media_type="text/event-stream")
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Streaming request failed") from e
@router.post("/reset-all-fail-counts")
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
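The `list_models` refactor earlier in this file replaces three near-identical loops with one `add_derived_model` helper that works on a deep copy of the cached model list. The following is a self-contained sketch of that idea, not the project's code: the function name `add_derived_models`, its signature, and the sample data are assumptions for illustration.

```python
from copy import deepcopy
from typing import Any, Dict, Iterable


def add_derived_models(
    models_data: Dict[str, Any],
    base_names: Iterable[str],
    suffix: str,
    display_suffix: str,
) -> Dict[str, Any]:
    """Return a copy of models_data with one derived entry per known base model."""
    # Work on a copy so a cached response is never mutated.
    result = deepcopy(models_data)
    # Index by the part after "models/"; [-1] also tolerates names without a slash.
    mapping = {
        m.get("name", "").split("/", maxsplit=1)[-1]: m
        for m in result.get("models", [])
    }
    for base_name in base_names:
        model = mapping.get(base_name)
        if model is None:
            continue  # Unknown base model: skip instead of failing the request.
        item = deepcopy(model)
        item["name"] = f"models/{base_name}{suffix}"
        display_name = f'{item.get("displayName", base_name)}{display_suffix}'
        item["displayName"] = display_name
        item["description"] = display_name
        result["models"].append(item)
    return result


if __name__ == "__main__":
    cached = {"models": [{"name": "models/gemini-2.0-flash", "displayName": "Gemini 2.0 Flash"}]}
    listed = add_derived_models(cached, ["gemini-2.0-flash"], "-search", " For Search")
    print([m["name"] for m in listed["models"]])
```

The switch from `[1]` to `[-1]` in the diff matters for names without a `/`: `split("/", maxsplit=1)` then returns a single-element list, so `[1]` raises `IndexError` while `[-1]` returns the whole name.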

View File

@@ -0,0 +1,121 @@
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from app.config.config import settings
from app.core.security import SecurityService
from app.domain.openai_models import (
ChatRequest,
EmbeddingRequest,
ImageGenerationRequest,
)
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.log.logger import get_openai_compatible_logger
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
router = APIRouter()
logger = get_openai_compatible_logger()
# Initialize services
security_service = SecurityService()
async def get_key_manager():
return await get_key_manager_instance()
async def get_next_working_key_wrapper(
key_manager: KeyManager = Depends(get_key_manager),
):
return await key_manager.get_next_working_key()
async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)):
"""Get the OpenAI chat service instance."""
return OpenAICompatiableService(settings.BASE_URL, key_manager)
@router.get("/openai/v1/models")
async def list_models(
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
"""Get the list of available models."""
operation_name = "list_models"
async with handle_route_errors(logger, operation_name):
logger.info("Handling models list request")
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
return await openai_service.get_models(api_key)
@router.post("/openai/v1/chat/completions")
@RetryHandler(max_retries=settings.MAX_RETRIES, key_arg="api_key")
async def chat_completion(
request: ChatRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
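"""Handle chat completion requests, supporting streaming responses and model-specific switching."""
operation_name = "chat_completion"
# Check whether this is an image-generation chat model; if so, use a paid key
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
current_api_key = api_key # Keep the original key, which may be a regular key
if is_image_chat:
current_api_key = await key_manager.get_paid_key() # Get the paid key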
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling chat completion request for model: {request.model}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {current_api_key}") # Use current_api_key
if is_image_chat:
# Image generation chat: call the dedicated service; streaming is not handled
response = await openai_service.create_image_chat_completion(request, current_api_key)
return response # Return the result directly
else:
# Regular chat completion
response = await openai_service.create_chat_completion(request, current_api_key)
# Handle streaming responses
if request.stream:
# Assumes openai_service.create_chat_completion returns an async generator when streaming
return StreamingResponse(response, media_type="text/event-stream")
# Non-streaming: return the result directly
return response
@router.post("/openai/v1/images/generations")
async def generate_image(
request: ImageGenerationRequest,
_=Depends(security_service.verify_authorization),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
"""Handle image generation requests."""
operation_name = "generate_image"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling image generation request for prompt: {request.prompt}")
# Force the configured model so the request carries the correct model information
request.model = settings.CREATE_IMAGE_MODEL
return await openai_service.generate_images(request)
@router.post("/openai/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
"""Handle text embedding requests."""
operation_name = "embedding"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling embedding request for model: {request.model}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
return await openai_service.create_embeddings(
input_text=request.input, model=request.model, api_key=api_key
)
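Every route in this new file wraps its body in `async with handle_route_errors(...)`, but the helper itself (`app/handler/error_handler.py`) is not part of this excerpt. A minimal sketch of how such an async context manager could behave is given below. It is an assumption about the shape of the helper, not its actual implementation: `HTTPError` is a stand-in for `fastapi.HTTPException` so the example runs without FastAPI, and the log messages and default detail text are invented for illustration.

```python
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional


class HTTPError(Exception):
    """Stand-in for fastapi.HTTPException (hypothetical, stdlib only)."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


@asynccontextmanager
async def handle_route_errors(
    logger: logging.Logger,
    operation_name: str,
    failure_message: Optional[str] = None,
) -> AsyncIterator[None]:
    """Log the start and outcome of a route and normalize unexpected errors."""
    logger.info("-" * 50 + operation_name + "-" * 50)
    try:
        yield
        # Reached when the body finishes normally, including via `return`.
        logger.info("%s request successful", operation_name)
    except HTTPError as http_exc:
        # Known HTTP errors keep their status code and are re-raised as-is.
        logger.error("%s failed with HTTP %s: %s", operation_name, http_exc.status_code, http_exc.detail)
        raise
    except Exception as exc:
        # Anything else becomes a generic 500 with the configured message.
        logger.error("%s failed: %s", operation_name, exc)
        detail = failure_message or f"Internal server error during {operation_name}"
        raise HTTPError(500, detail) from exc


async def demo() -> str:
    async with handle_route_errors(logging.getLogger("demo"), "demo_operation"):
        return "ok"


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

One consequence of this shape is that a `StreamingResponse` returned from inside the block only gets startup errors handled; failures raised while the stream is being consumed happen after the context manager has exited.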

View File

@@ -9,6 +9,7 @@ from app.domain.openai_models import (
ImageGenerationRequest,
)
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors # Import the shared error handler
from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
@@ -47,17 +48,13 @@ async def list_models(
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "list_models" + "-" * 50)
logger.info("Handling models list request")
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
try:
return model_service.get_gemini_openai_models(api_key)
except Exception as e:
logger.error(f"Error getting models list: {str(e)}")
raise HTTPException(
status_code=500, detail="Internal server error while fetching models list"
) from e
"""Get the list of available OpenAI models (compatible with Gemini and OpenAI)."""
operation_name = "list_models"
async with handle_route_errors(logger, operation_name):
logger.info("Handling models list request")
api_key = await key_manager.get_first_valid_key()
logger.info(f"Using API key: {api_key}")
return await model_service.get_gemini_openai_models(api_key)
@router.post("/v1/chat/completions")
@@ -70,33 +67,38 @@ async def chat_completion(
key_manager: KeyManager = Depends(get_key_manager), # Keep key_manager to obtain the paid key
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
):
# If the model is imagen3, use paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
api_key = await key_manager.get_paid_key()
logger.info("-" * 50 + "chat_completion" + "-" * 50)
logger.info(f"Handling chat completion request for model: {request.model}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
"""Handle OpenAI chat completion requests, supporting streaming responses and model-specific switching."""
operation_name = "chat_completion"
# Check whether this is an image-generation chat model
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
current_api_key = api_key # Keep the original key
if is_image_chat:
current_api_key = await key_manager.get_paid_key() # Get the paid key
if not model_service.check_model_support(request.model):
raise HTTPException(
status_code=400, detail=f"Model {request.model} is not supported"
)
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling chat completion request for model: {request.model}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {current_api_key}")
try:
# If the model is imagen3, use paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
response = await chat_service.create_image_chat_completion(request, api_key)
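# The model support check belongs inside the error-handling block so that errors are caught and logged
if not await model_service.check_model_support(request.model):
# An HTTPException raised here is caught and logged by handle_route_errors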
raise HTTPException(
status_code=400, detail=f"Model {request.model} is not supported"
)
if is_image_chat:
# Image generation chat
response = await chat_service.create_image_chat_completion(request, current_api_key)
return response # Return directly; streaming is not handled
else:
response = await chat_service.create_chat_completion(request, api_key)
# Handle streaming responses
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
logger.info("Chat completion request successful")
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
# Regular chat completion
response = await chat_service.create_chat_completion(request, current_api_key)
# Handle streaming responses
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
# Non-streaming: return the result directly
return response
@router.post("/v1/images/generations")
@@ -105,18 +107,14 @@ async def generate_image(
request: ImageGenerationRequest,
_=Depends(security_service.verify_authorization),
):
logger.info("-" * 50 + "generate_image" + "-" * 50)
logger.info(f"Handling image generation request for prompt: {request.prompt}")
try:
"""Handle OpenAI image generation requests."""
operation_name = "generate_image"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling image generation request for prompt: {request.prompt}")
# Note: this assumes image_create_service.generate_images is a synchronous function
# If it is asynchronous, it must be awaited
response = image_create_service.generate_images(request)
logger.info("Image generation request successful")
return response
except Exception as e:
logger.error(f"Image generation request failed: {str(e)}")
raise HTTPException(
status_code=500, detail="Image generation request failed"
) from e
@router.post("/v1/embeddings")
@@ -126,19 +124,16 @@ async def embedding(
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "embedding" + "-" * 50)
logger.info(f"Handling embedding request for model: {request.model}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
"""Handle OpenAI text embedding requests."""
operation_name = "embedding"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling embedding request for model: {request.model}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
response = await embedding_service.create_embedding(
input_text=request.input, model=request.model, api_key=api_key
)
logger.info("Embedding request successful")
return response
except Exception as e:
logger.error(f"Embedding request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Embedding request failed") from e
@router.get("/v1/keys/list")
@@ -147,10 +142,10 @@ async def get_keys_list(
_=Depends(security_service.verify_auth_token),
key_manager: KeyManager = Depends(get_key_manager),
):
"""Get the lists of valid and invalid API keys."""
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
logger.info("Handling keys list request")
try:
"""Get the lists of valid and invalid API keys (requires admin token authentication)."""
operation_name = "get_keys_list"
async with handle_route_errors(logger, operation_name):
logger.info("Handling keys list request")
keys_status = await key_manager.get_keys_by_status()
return {
"status": "success",
@@ -160,8 +155,3 @@ async def get_keys_list(
},
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
}
except Exception as e:
logger.error(f"Error getting keys list: {str(e)}")
raise HTTPException(
status_code=500, detail="Internal server error while fetching keys list"
) from e

View File

@@ -8,9 +8,9 @@ from fastapi.templating import Jinja2Templates
from app.core.security import verify_auth_token
from app.log.logger import get_routes_logger
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes # Newly added import of version_routes
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes
from app.service.key.key_manager import get_key_manager_instance
from app.service.stats_service import StatsService
from app.service.stats.stats_service import StatsService
logger = get_routes_logger()
@@ -31,9 +31,10 @@ def setup_routers(app: FastAPI) -> None:
app.include_router(gemini_routes.router_v1beta)
app.include_router(config_routes.router)
app.include_router(error_log_routes.router)
app.include_router(scheduler_routes.router) # Newly added: include the scheduler routes
app.include_router(stats_routes.router) # Include the stats API routes
app.include_router(version_routes.router) # Include the version API routes
app.include_router(scheduler_routes.router)
app.include_router(stats_routes.router)
app.include_router(version_routes.router)
app.include_router(openai_compatiable_routes.router)
# Add page routes
setup_page_routes(app)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from starlette import status
from app.core.security import verify_auth_token
from app.service.stats_service import StatsService
from app.service.stats.stats_service import StatsService
from app.log.logger import get_stats_logger
logger = get_stats_logger()

View File

@@ -81,10 +81,10 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]

View File

@@ -1,13 +1,17 @@
# app/services/chat_service.py
import datetime
import json
import re
import datetime # Add datetime import
import time # Add time import
import time
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from app.config.config import settings
from app.database.services import (
add_error_log,
add_request_log,
)
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
@@ -16,17 +20,16 @@ from app.log.logger import get_openai_logger
from app.service.client.api_client import GeminiApiClient
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager
from app.database.services import add_error_log, add_request_log # Import add_request_log
logger = get_openai_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""Check whether the messages contain image parts."""
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
"""Check whether the messages contain image, audio, or video parts (inline_data)."""
for content in contents:
if "parts" in content:
if content and "parts" in content and isinstance(content["parts"], list):
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
if isinstance(part, dict) and "inline_data" in part:
return True
return False
@@ -46,9 +49,13 @@ def _build_tools(
or model.endswith("-image")
or model.endswith("-image-generation")
)
and not _has_image_parts(messages)
and not _has_media_parts(messages) # Use the updated check
):
tool["codeExecution"] = {}
logger.debug("Code execution tool enabled.")
elif _has_media_parts(messages):
logger.debug("Code execution tool disabled due to media parts presence.")
if model.endswith("-search"):
tool["googleSearch"] = {}
@@ -62,7 +69,9 @@ def _build_tools(
if item.get("type", "") == "function" and item.get("function"):
function = deepcopy(item.get("function"))
parameters = function.get("parameters", {})
if parameters.get("type") == "object" and not parameters.get("properties", {}):
if parameters.get("type") == "object" and not parameters.get(
"properties", {}
):
function.pop("parameters", None)
function_declarations.append(function)
@@ -93,20 +102,8 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
# and "gemini-2.0-pro-exp" not in model
# ):
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
return settings.GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
return settings.SAFETY_SETTINGS
def _build_payload(
@@ -131,9 +128,11 @@ def _build_payload(
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
if request.model.endswith("-non-thinking"):
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
if request.model in settings.THINKING_BUDGET_MAP:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model,1000)}
payload["generationConfig"]["thinkingConfig"] = {
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
}
if (
instruction
@@ -205,7 +204,7 @@ class OpenAIChatService:
try:
response = await self.api_client.generate_content(payload, model, api_key)
is_success = True
status_code = 200 # Assume 200 on success
status_code = 200
return self.response_handler.handle_response(
response, model, stream=False, finish_reason="stop"
)
@@ -218,17 +217,17 @@ class OpenAIChatService:
if match:
status_code = int(match.group(1))
else:
status_code = 500 # Default if parsing fails
status_code = 500
await add_error_log(
gemini_key=api_key, # Note: Parameter name is gemini_key in add_error_log
gemini_key=api_key,
model_name=model,
error_type="openai-chat-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload
request_msg=payload,
)
raise e # Re-throw exception
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
@@ -238,7 +237,7 @@ class OpenAIChatService:
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
request_time=request_datetime,
)
async def _handle_stream_completion(
@@ -261,6 +260,7 @@ class OpenAIChatService:
async for line in self.api_client.stream_generate_content(
payload, model, current_attempt_key
):
# print(line)
if line.startswith("data:"):
chunk = json.loads(line[6:])
openai_chunk = self.response_handler.handle_response(
@@ -293,7 +293,7 @@ class OpenAIChatService:
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
is_success = True
status_code = 200 # Assume 200 on success
status_code = 200
break # Exit the loop on success
except Exception as e:
retries += 1
@@ -307,7 +307,7 @@ class OpenAIChatService:
if match:
status_code = int(match.group(1))
else:
status_code = 500 # Default if parsing fails
status_code = 500
# Log error to error log table
await add_error_log(
@@ -316,38 +316,40 @@ class OpenAIChatService:
error_type="openai-chat-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload
request_msg=payload,
)
# Attempt to switch API Key
# Ensure key_manager is available (might need adjustment if not always passed)
if self.key_manager:
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
api_key = await self.key_manager.handle_api_failure(
current_attempt_key, retries
)
if api_key:
logger.info(f"Switched to new API key: {api_key}")
else:
logger.error(f"No valid API key available after {retries} retries.")
break # Exit loop if no key available
logger.error(
f"No valid API key available after {retries} retries."
)
break
else:
logger.error("KeyManager not available for retry logic.")
break # Exit loop if key manager is missing
logger.error("KeyManager not available for retry logic.")
break
if retries >= max_retries:
logger.error(
f"Max retries ({max_retries}) reached for streaming."
)
break # Exit loop after max retries
logger.error(f"Max retries ({max_retries}) reached for streaming.")
break
finally:
# Log the final outcome of the streaming request
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=final_api_key, # Log the last key used
is_success=is_success, # Log the final success status
status_code=status_code, # Log the last known status code
latency_ms=latency_ms, # Log total time including retries
request_time=request_datetime
api_key=final_api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime,
)
# If the loop finished due to failure, yield error and DONE
if not is_success and retries >= max_retries:
@@ -355,9 +357,7 @@ class OpenAIChatService:
yield "data: [DONE]\n\n"
async def create_image_chat_completion(
self,
request: ChatRequest,
api_key: str
self, request: ChatRequest, api_key: str
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
image_generate_request = ImageGenerationRequest()
@@ -367,18 +367,22 @@ class OpenAIChatService:
)
if request.stream:
return self._handle_stream_image_completion(request.model, image_res, api_key)
return self._handle_stream_image_completion(
request.model, image_res, api_key
)
else:
return await self._handle_normal_image_completion(request.model, image_res, api_key)
return await self._handle_normal_image_completion(
request.model, image_res, api_key
)
async def _handle_stream_image_completion(
self, model: str, image_data: str, api_key:str
self, model: str, image_data: str, api_key: str
) -> AsyncGenerator[str, None]:
logger.info(f"Starting stream image completion for model: {model}")
start_time = time.perf_counter()
request_datetime = datetime.datetime.now() # Although not used for DB log here
request_datetime = datetime.datetime.now()
is_success = False
status_code = None # Although not used for DB log here
status_code = None
try:
if image_data:
@@ -402,7 +406,9 @@ class OpenAIChatService:
# If there is no text content (e.g. an image URL), emit the whole chunk at once
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
logger.info(f"Stream image completion finished successfully for model: {model}")
logger.info(
f"Stream image completion finished successfully for model: {model}"
)
is_success = True
status_code = 200
yield "data: [DONE]\n\n"
@@ -410,46 +416,51 @@ class OpenAIChatService:
is_success = False
error_log_msg = f"Stream image completion failed for model {model}: {e}"
logger.error(error_log_msg)
status_code = 500 # Default error code
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-image-stream", # Specific error type
error_type="openai-image-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
request_msg={
"image_data_truncated": image_data[:1000]
},
)
yield f"data: {json.dumps({'error': error_log_msg})}\n\n" # Send error to client
yield "data: [DONE]\n\n" # Still need DONE message
# Re-raising might break the stream, decide if needed
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
yield "data: [DONE]\n\n"
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
logger.info(f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}")
logger.info(
f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
request_time=request_datetime,
)
async def _handle_normal_image_completion(
self, model: str, image_data: str, api_key: str # Add api_key parameter
self, model: str, image_data: str, api_key: str
) -> Dict[str, Any]:
logger.info(f"Starting normal image completion for model: {model}")
start_time = time.perf_counter()
request_datetime = datetime.datetime.now() # Although not used for DB log here
request_datetime = datetime.datetime.now()
is_success = False
status_code = None # Although not used for DB log here
status_code = None
result = None
try:
result = self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)
logger.info(f"Normal image completion finished successfully for model: {model}")
logger.info(
f"Normal image completion finished successfully for model: {model}"
)
is_success = True
status_code = 200
return result
@@ -457,26 +468,30 @@ class OpenAIChatService:
is_success = False
error_log_msg = f"Normal image completion failed for model {model}: {e}"
logger.error(error_log_msg)
status_code = 500 # Default error code
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-image-non-stream", # Specific error type
error_type="openai-image-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
request_msg={
"image_data_truncated": image_data[:1000]
},
)
# Re-raise the exception so the caller knows about the failure
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
logger.info(f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}")
logger.info(
f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
request_time=request_datetime,
)
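
The two image-completion handlers above share one shape: time the call, set a success flag and status code, and write the request log in `finally` so it is recorded on both paths. A minimal sketch of that pattern follows; `record_request`, `run_logged`, and the in-memory `REQUEST_LOG` are hypothetical stand-ins for illustration and are not part of the project.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable, Dict, List

# Hypothetical in-memory stand-in for the database-backed add_request_log.
REQUEST_LOG: List[Dict[str, Any]] = []


async def record_request(**fields: Any) -> None:
    REQUEST_LOG.append(fields)


async def run_logged(model: str, work: Callable[[], Awaitable[Any]]) -> Any:
    """Run `work`, then record the success flag, status code, and latency."""
    start = time.perf_counter()
    is_success = False
    status_code = None
    try:
        result = await work()
        is_success = True
        status_code = 200
        return result
    except Exception:
        status_code = 500  # default error code, as in the diff
        raise  # re-raise so the caller still sees the failure
    finally:
        # Runs on both the success and the failure path.
        latency_ms = int((time.perf_counter() - start) * 1000)
        await record_request(
            model_name=model,
            is_success=is_success,
            status_code=status_code,
            latency_ms=latency_ms,
        )


async def _ok() -> str:
    return "image-url"


async def _boom() -> str:
    raise RuntimeError("upstream failed")


async def demo() -> None:
    await run_logged("demo-model", _ok)
    try:
        await run_logged("demo-model", _boom)
    except RuntimeError:
        pass


asyncio.run(demo())
```

Because the log write sits in `finally`, exactly one request-log entry is produced per call regardless of outcome.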


@@ -1,11 +1,14 @@
# app/services/chat/api_client.py
from typing import Dict, Any, AsyncGenerator
from typing import Dict, Any, AsyncGenerator, Optional
import httpx
import random
from abc import ABC, abstractmethod
from app.config.config import settings
from app.log.logger import get_api_client_logger
from app.core.constants import DEFAULT_TIMEOUT
logger = get_api_client_logger()
class ApiClient(ABC):
"""Base class for API clients."""
@@ -37,11 +40,41 @@ class GeminiApiClient(ApiClient):
model = model[:-20]
return model
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Fetch the list of available Gemini models."""
timeout = httpx.Timeout(timeout=5)
proxy_to_use = None
if settings.PROXIES:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models?key={api_key}"
try:
response = await client.get(url)
response.raise_for_status() # Raises HTTPStatusError if the status code is not 2xx
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Failed to fetch model list: {e.response.status_code}")
logger.error(e.response.text)
# Return None instead of raising so the caller can handle it
return None
except httpx.RequestError as e:
logger.error(f"Request for model list failed: {e}")
# Return None instead of raising
return None
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
proxy_to_use = None
if settings.PROXIES:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy: {proxy_to_use}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
if response.status_code != 200:
@@ -53,7 +86,12 @@ class GeminiApiClient(ApiClient):
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
proxy_to_use = None
if settings.PROXIES:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy: {proxy_to_use}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response:
if response.status_code != 200:
@@ -62,3 +100,96 @@ class GeminiApiClient(ApiClient):
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
async for line in response.aiter_lines():
yield line
class OpenaiApiClient(ApiClient):
"""OpenAI API client."""
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
self.base_url = base_url
self.timeout = timeout
async def get_models(self, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/openai/models"
headers = {"Authorization": f"Bearer {api_key}"}
response = await client.get(url, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
proxy_to_use = None
if settings.PROXIES:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy: {proxy_to_use}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/chat/completions"
headers = {"Authorization": f"Bearer {api_key}"}
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
proxy_to_use = None
if settings.PROXIES:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy: {proxy_to_use}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/chat/completions"
headers = {"Authorization": f"Bearer {api_key}"}
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
if response.status_code != 200:
error_content = await response.aread()
error_msg = error_content.decode("utf-8")
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
async for line in response.aiter_lines():
yield line
async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
proxy_to_use = None
if settings.PROXIES:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy: {proxy_to_use}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/embeddings"
headers = {"Authorization": f"Bearer {api_key}"}
payload = {
"input": input,
"model": model,
}
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
proxy_to_use = None
if settings.PROXIES:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy: {proxy_to_use}")
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/images/generations"
headers = {"Authorization": f"Bearer {api_key}"}
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
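
The same proxy-selection lines (`proxy_to_use = None`, then `random.choice(settings.PROXIES)` when the list is non-empty) recur in nearly every client method above. One way to factor them out is sketched below; `pick_proxy` is a hypothetical helper name, not something the commit defines, and the optional `rng` argument is only an assumption added to make the choice reproducible in tests.

```python
import random
from typing import Optional, Sequence


def pick_proxy(
    proxies: Sequence[str], rng: Optional[random.Random] = None
) -> Optional[str]:
    """Return a randomly chosen proxy URL, or None when none are configured."""
    if not proxies:
        return None
    chooser = rng or random  # fall back to the module-level generator
    return chooser.choice(proxies)
```

A caller would pass the result straight to the HTTP client, since the diff already passes `None` through when no proxy is configured.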


@@ -1,50 +1,47 @@
from datetime import datetime, timezone
from typing import Any, Dict, Optional
import requests
from app.config.config import settings
from app.log.logger import get_model_logger
from app.service.client.api_client import GeminiApiClient
logger = get_model_logger()
class ModelService:
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
url = f"{settings.BASE_URL}/models?key={api_key}"
async def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Fetch the model list via GeminiApiClient and filter it."""
api_client = GeminiApiClient(base_url=settings.BASE_URL) # Instantiate the client
gemini_models = await api_client.get_models(api_key)
try:
response = requests.get(url)
if response.status_code == 200:
gemini_models = response.json()
filtered_models_list = []
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]
if model_id not in settings.FILTERED_MODELS:
filtered_models_list.append(model)
else:
logger.debug(f"Filtered out model: {model_id}")
gemini_models["models"] = filtered_models_list
return gemini_models
else:
logger.error(f"Error: {response.status_code}")
logger.error(response.text)
return None
except requests.RequestException as e:
logger.error(f"Request failed: {e}")
if gemini_models is None:
logger.error("Failed to fetch the model list from the API client.")
return None
def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
try:
gemini_models = self.get_gemini_models(api_key)
return self.convert_to_openai_models_format(gemini_models)
except requests.RequestException as e:
logger.error(f"Request failed: {e}")
filtered_models_list = []
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]
if model_id not in settings.FILTERED_MODELS:
filtered_models_list.append(model)
else:
logger.debug(f"Filtered out model: {model_id}")
gemini_models["models"] = filtered_models_list
return gemini_models
except Exception as e:
logger.error(f"Error while processing the model list: {e}")
return None
def convert_to_openai_models_format(
async def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Fetch Gemini models and convert them to the OpenAI format."""
gemini_models = await self.get_gemini_models(api_key)
if gemini_models is None:
return None
return await self.convert_to_openai_models_format(gemini_models)
async def convert_to_openai_models_format(
self, gemini_models: Dict[str, Any]
) -> Dict[str, Any]:
openai_format = {"object": "list", "data": [], "success": True}
@@ -81,7 +78,7 @@ class ModelService:
openai_format["data"].append(image_model)
return openai_format
def check_model_support(self, model: str) -> bool:
async def check_model_support(self, model: str) -> bool:
if not model or not isinstance(model, str):
return False
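
The filtering step in `get_gemini_models` takes the last path segment of each model's `name` and drops any model whose ID is in `settings.FILTERED_MODELS`. A standalone sketch of that step is shown here; `filter_models` is a hypothetical function name, and unlike the code in the diff it returns a new dict rather than mutating the payload in place.

```python
from typing import Any, Dict, Iterable


def filter_models(
    payload: Dict[str, Any], filtered_ids: Iterable[str]
) -> Dict[str, Any]:
    """Drop models whose ID (last segment of `name`) is in `filtered_ids`."""
    blocked = set(filtered_ids)
    kept = [
        model
        for model in payload.get("models", [])
        if model["name"].split("/")[-1] not in blocked
    ]
    # Return a copy so the caller's payload is left untouched.
    return {**payload, "models": kept}


sample = {
    "models": [
        {"name": "models/gemini-1.0-pro-latest"},
        {"name": "models/gemini-1.5-flash-latest"},
    ]
}
result = filter_models(sample, ["gemini-1.0-pro-latest"])
```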


@@ -0,0 +1,197 @@
import datetime
import json
import re
import time
from typing import Any, AsyncGenerator, Dict, Union
from app.config.config import settings
from app.database.services import (
add_error_log,
add_request_log,
)
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.service.client.api_client import OpenaiApiClient
from app.service.key.key_manager import KeyManager
from app.log.logger import get_openai_compatible_logger
logger = get_openai_compatible_logger()
class OpenAICompatiableService:
def __init__(self, base_url: str, key_manager: KeyManager = None):
self.key_manager = key_manager
self.base_url = base_url
self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)
async def get_models(self, api_key: str) -> Dict[str, Any]:
return await self.api_client.get_models(api_key)
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
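) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Create a chat completion."""
request_dict = request.model_dump()
# Drop entries whose value is None
request_dict = {k: v for k, v in request_dict.items() if v is not None}
request_dict.pop("top_k", None) # Remove top_k (not currently supported); pop avoids a KeyError if the None filter above already dropped it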
if request.stream:
return self._handle_stream_completion(request.model, request_dict, api_key)
return await self._handle_normal_completion(request.model, request_dict, api_key)
async def generate_images(
self,
request: ImageGenerationRequest,
) -> Dict[str, Any]:
"""Generate images."""
request_dict = request.model_dump()
# Drop entries whose value is None
request_dict = {k: v for k, v in request_dict.items() if v is not None}
api_key = settings.PAID_KEY
return await self.api_client.generate_images(request_dict, api_key)
async def create_embeddings(
self,
input_text: str,
model: str,
api_key: str,
) -> Dict[str, Any]:
"""Create embeddings."""
return await self.api_client.create_embeddings(input_text, model, api_key)
async def _handle_normal_completion(
self, model: str, request: dict, api_key: str
) -> Dict[str, Any]:
"""Handle a non-streaming chat completion."""
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
is_success = False
status_code = None
response = None
try:
response = await self.api_client.generate_content(request, api_key)
is_success = True
status_code = 200
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
logger.error(f"Normal API call failed with error: {error_log_msg}")
# Try to parse status code from exception
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-compatiable-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=request,
)
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime,
)
async def _handle_stream_completion(
self, model: str, payload: dict, api_key: str
) -> AsyncGenerator[str, None]:
"""Handle a streaming chat completion, with retry logic."""
retries = 0
max_retries = settings.MAX_RETRIES
is_success = False
status_code = None
final_api_key = api_key
while retries < max_retries:
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
current_attempt_key = api_key
final_api_key = current_attempt_key
try:
async for line in self.api_client.stream_generate_content(
payload, current_attempt_key
):
if line.startswith("data:"):
# print(line)
yield line + "\n\n"
logger.info("Streaming completed successfully")
is_success = True
status_code = 200
break # Exit the loop on success
except Exception as e:
retries += 1
is_success = False
error_log_msg = str(e)
logger.warning(
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
)
# Parse error code for logging
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
# Log error to error log table
await add_error_log(
gemini_key=current_attempt_key,
model_name=model,
error_type="openai-compatiable-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
)
# Attempt to switch API Key
# Ensure key_manager is available (might need adjustment if not always passed)
if self.key_manager:
api_key = await self.key_manager.handle_api_failure(
current_attempt_key, retries
)
if api_key:
logger.info(f"Switched to new API key: {api_key}")
else:
logger.error(
f"No valid API key available after {retries} retries."
)
break
else:
logger.error("KeyManager not available for retry logic.")
break
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming.")
break
finally:
# Log the final outcome of the streaming request
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=final_api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime,
)
# If the loop finished due to failure, yield error and DONE
if not is_success and retries >= max_retries:
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"
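
Both handlers above recover the HTTP status from the exception text with `re.search(r"status code (\d+)", ...)` and fall back to 500 when nothing matches. This works because `OpenaiApiClient` raises messages of the form `API call failed with status code N, ...`. A small sketch of that parsing step, using a hypothetical helper name `status_code_from_error`:

```python
import re

_STATUS_RE = re.compile(r"status code (\d+)")


def status_code_from_error(message: str, default: int = 500) -> int:
    """Extract the status code from an error message, else return `default`."""
    match = _STATUS_RE.search(message)
    return int(match.group(1)) if match else default
```

Parsing the message is brittle if its wording changes; carrying the status code on a dedicated exception type would avoid that, though the commit does not do so.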


@@ -1,3 +1,7 @@
// Hoist DOM elements that functions outside the handler need to access
const safetySettingsContainer = document.getElementById('SAFETY_SETTINGS_container');
const thinkingModelsContainer = document.getElementById('THINKING_MODELS_container');
document.addEventListener('DOMContentLoaded', function() {
// Initialize configuration
initConfig();
@@ -63,13 +67,29 @@ document.addEventListener('DOMContentLoaded', function() {
const cancelBulkDeleteApiKeyBtn = document.getElementById('cancelBulkDeleteApiKeyBtn'); // Added
const confirmBulkDeleteApiKeyBtn = document.getElementById('confirmBulkDeleteApiKeyBtn'); // Added
const bulkDeleteApiKeyInput = document.getElementById('bulkDeleteApiKeyInput'); // Added
// --- Added: proxy modal elements ---
const proxyModal = document.getElementById('proxyModal');
const addProxyBtn = document.getElementById('addProxyBtn'); // Changed from bulkAddProxyBtn
const closeProxyModalBtn = document.getElementById('closeProxyModalBtn');
const cancelAddProxyBtn = document.getElementById('cancelAddProxyBtn');
const confirmAddProxyBtn = document.getElementById('confirmAddProxyBtn');
const proxyBulkInput = document.getElementById('proxyBulkInput');
const bulkDeleteProxyBtn = document.getElementById('bulkDeleteProxyBtn'); // Added
const bulkDeleteProxyModal = document.getElementById('bulkDeleteProxyModal'); // Added
const closeBulkDeleteProxyModalBtn = document.getElementById('closeBulkDeleteProxyModalBtn'); // Added
const cancelBulkDeleteProxyBtn = document.getElementById('cancelBulkDeleteProxyBtn'); // Added
const confirmBulkDeleteProxyBtn = document.getElementById('confirmBulkDeleteProxyBtn'); // Added
const bulkDeleteProxyInput = document.getElementById('bulkDeleteProxyInput'); // Added
// --- End: proxy modal elements ---
// --- Added: reset confirmation modal elements ---
const resetConfirmModal = document.getElementById('resetConfirmModal');
const closeResetModalBtn = document.getElementById('closeResetModalBtn');
const cancelResetBtn = document.getElementById('cancelResetBtn');
const confirmResetBtn = document.getElementById('confirmResetBtn');
// --- End: added ---
// const safetySettingsContainer = document.getElementById('SAFETY_SETTINGS_container'); // Moved outside
// Open the modal
@@ -111,8 +131,14 @@ document.addEventListener('DOMContentLoaded', function() {
if (event.target == bulkDeleteApiKeyModal) { // Added: handle the bulk-delete modal
bulkDeleteApiKeyModal.classList.remove('show');
}
if (event.target == proxyModal) { // Added: handle the proxy modal
proxyModal.classList.remove('show');
}
if (event.target == bulkDeleteProxyModal) { // Added: handle the bulk-delete proxy modal
bulkDeleteProxyModal.classList.remove('show');
}
});
// Confirm adding API keys
if (confirmAddApiKeyBtn) {
confirmAddApiKeyBtn.addEventListener('click', handleBulkAddApiKeys);
@@ -158,7 +184,77 @@ document.addEventListener('DOMContentLoaded', function() {
}
// --- End: bulk-delete API key handling ---
// --- End: API key handling ---
// --- Added: proxy modal events ---
// Open the modal (Changed event listener to addProxyBtn)
if (addProxyBtn) {
addProxyBtn.addEventListener('click', () => {
if (proxyModal) {
proxyModal.classList.add('show');
}
if (proxyBulkInput) proxyBulkInput.value = ''; // Clear the input
});
}
// Close the modal (X button)
if (closeProxyModalBtn) {
closeProxyModalBtn.addEventListener('click', () => {
if (proxyModal) {
proxyModal.classList.remove('show');
}
});
}
// Close the modal (Cancel button)
if (cancelAddProxyBtn) {
cancelAddProxyBtn.addEventListener('click', () => {
if (proxyModal) {
proxyModal.classList.remove('show');
}
});
}
// Confirm adding proxies
if (confirmAddProxyBtn) {
confirmAddProxyBtn.addEventListener('click', handleBulkAddProxies);
}
// --- End: proxy modal events ---
// --- Added: bulk-delete proxy events ---
// Open the bulk-delete modal
if (bulkDeleteProxyBtn) {
bulkDeleteProxyBtn.addEventListener('click', () => {
if (bulkDeleteProxyModal) {
bulkDeleteProxyModal.classList.add('show');
}
if (bulkDeleteProxyInput) bulkDeleteProxyInput.value = ''; // Clear the input
});
}
// Close the bulk-delete modal (X button)
if (closeBulkDeleteProxyModalBtn) {
closeBulkDeleteProxyModalBtn.addEventListener('click', () => {
if (bulkDeleteProxyModal) {
bulkDeleteProxyModal.classList.remove('show');
}
});
}
// Close the bulk-delete modal (Cancel button)
if (cancelBulkDeleteProxyBtn) {
cancelBulkDeleteProxyBtn.addEventListener('click', () => {
if (bulkDeleteProxyModal) {
bulkDeleteProxyModal.classList.remove('show');
}
});
}
// Confirm bulk-deleting proxies
if (confirmBulkDeleteProxyBtn) {
confirmBulkDeleteProxyBtn.addEventListener('click', handleBulkDeleteProxies);
}
// --- End: bulk-delete proxy handling ---
// --- Added: reset confirmation modal event listeners (moved inside DOMContentLoaded) ---
if (closeResetModalBtn) {
closeResetModalBtn.addEventListener('click', () => {
@@ -206,7 +302,7 @@ document.addEventListener('DOMContentLoaded', function() {
// --- End: thinking model budget map ---
// Event delegation: handle input events from dynamically added THINKING_MODELS inputs
const thinkingModelsContainer = document.getElementById('THINKING_MODELS_container');
// const thinkingModelsContainer = document.getElementById('THINKING_MODELS_container'); // Moved outside
if (thinkingModelsContainer) {
thinkingModelsContainer.addEventListener('input', function(event) {
if (event.target && event.target.classList.contains('array-input') && event.target.closest('.array-item[data-model-id]')) {
@@ -220,6 +316,12 @@ document.addEventListener('DOMContentLoaded', function() {
});
}
// --- Added: safety settings ---
const addSafetySettingBtn = document.getElementById('addSafetySettingBtn');
if (addSafetySettingBtn) {
addSafetySettingBtn.addEventListener('click', () => addSafetySettingItem());
}
// --- End: safety settings ---
}); // <-- end of DOMContentLoaded
@@ -265,6 +367,10 @@ async function initConfig() {
if (!config.FILTERED_MODELS || !Array.isArray(config.FILTERED_MODELS) || config.FILTERED_MODELS.length === 0) {
config.FILTERED_MODELS = ['gemini-1.0-pro-latest'];
}
// --- Added: default for PROXIES ---
if (!config.PROXIES || !Array.isArray(config.PROXIES)) {
config.PROXIES = []; // Default to an empty array
}
// --- Added: defaults for new fields ---
if (!config.THINKING_MODELS || !Array.isArray(config.THINKING_MODELS)) {
config.THINKING_MODELS = []; // Default to an empty array
@@ -272,7 +378,11 @@ async function initConfig() {
if (!config.THINKING_BUDGET_MAP || typeof config.THINKING_BUDGET_MAP !== 'object' || config.THINKING_BUDGET_MAP === null) {
config.THINKING_BUDGET_MAP = {}; // Default to an empty object
}
// --- End: defaults for new fields ---
// --- Added: default for SAFETY_SETTINGS ---
if (!config.SAFETY_SETTINGS || !Array.isArray(config.SAFETY_SETTINGS)) {
config.SAFETY_SETTINGS = []; // Default to an empty array
}
// --- End: default for SAFETY_SETTINGS ---
populateForm(config);
@@ -296,6 +406,7 @@ async function initConfig() {
SEARCH_MODELS: ['gemini-1.5-flash-latest'],
FILTERED_MODELS: ['gemini-1.0-pro-latest'],
UPLOAD_PROVIDER: 'smms',
PROXIES: [], // Add default value
THINKING_MODELS: [],
THINKING_BUDGET_MAP: {}
};
@@ -410,6 +521,24 @@ function populateForm(config) {
if (uploadProvider) {
toggleProviderConfig(uploadProvider.value);
}
// --- Added: populate SAFETY_SETTINGS ---
let safetyItemsAdded = false;
if (safetySettingsContainer && Array.isArray(config.SAFETY_SETTINGS)) {
config.SAFETY_SETTINGS.forEach(setting => {
if (setting && typeof setting === 'object' && setting.category && setting.threshold) {
addSafetySettingItem(setting.category, setting.threshold);
safetyItemsAdded = true;
} else {
console.warn("Invalid safety setting item found:", setting);
}
});
}
// If no safety setting items were added, show the placeholder
if (safetySettingsContainer && !safetyItemsAdded) {
safetySettingsContainer.innerHTML = '<div class="text-gray-500 text-sm italic">Define the safety filtering thresholds for the model.</div>';
}
// --- End: populate SAFETY_SETTINGS ---
}
// --- Added: bulk-add API key logic ---
@@ -521,6 +650,92 @@ function handleBulkDeleteApiKeys() {
bulkDeleteTextarea.value = '';
}
// --- Added: bulk-add proxy logic ---
function handleBulkAddProxies() {
const proxyBulkInput = document.getElementById('proxyBulkInput');
const proxyContainer = document.getElementById('PROXIES_container');
const proxyModal = document.getElementById('proxyModal');
if (!proxyBulkInput || !proxyContainer || !proxyModal) return;
const bulkText = proxyBulkInput.value;
// Match http(s):// or socks5:// proxies, optionally including a username and password
const proxyRegex = /(?:https?|socks5):\/\/(?:[^:@\/]+(?::[^@\/]+)?@)?(?:[^:\/\s]+)(?::\d+)?/g;
const extractedProxies = bulkText.match(proxyRegex) || [];
// Collect the proxies already in the list
const currentProxyInputs = proxyContainer.querySelectorAll('.array-input');
const currentProxies = Array.from(currentProxyInputs).map(input => input.value).filter(proxy => proxy.trim() !== '');
// Merge and deduplicate
const combinedProxies = new Set([...currentProxies, ...extractedProxies]);
const uniqueProxies = Array.from(combinedProxies);
// Clear the currently displayed list
const existingItems = proxyContainer.querySelectorAll('.array-item');
existingItems.forEach(item => item.remove());
// Repopulate the list
uniqueProxies.forEach(proxy => {
addArrayItemWithValue('PROXIES', proxy);
});
// Close the modal
proxyModal.classList.remove('show');
showNotification(`Added/updated ${uniqueProxies.length} unique proxies`, 'success');
}
// --- End: bulk-add proxy logic ---
// --- Added: bulk-delete proxy logic ---
function handleBulkDeleteProxies() {
const bulkDeleteTextarea = document.getElementById('bulkDeleteProxyInput');
const proxyContainer = document.getElementById('PROXIES_container');
const bulkDeleteModal = document.getElementById('bulkDeleteProxyModal');
if (!bulkDeleteTextarea || !proxyContainer || !bulkDeleteModal) return;
const bulkText = bulkDeleteTextarea.value;
if (!bulkText.trim()) {
showNotification('Please paste the proxy addresses to delete', 'warning');
return;
}
// Use the same regular expression as the add flow to extract the proxies to delete
const proxyRegex = /(?:https?|socks5):\/\/(?:[^:@\/]+(?::[^@\/]+)?@)?(?:[^:\/\s]+)(?::\d+)?/g;
const proxiesToDelete = new Set(bulkText.match(proxyRegex) || []); // Use a Set for efficient lookup
if (proxiesToDelete.size === 0) {
showNotification('No valid proxy address format was found in the input', 'warning');
return;
}
const proxyItems = proxyContainer.querySelectorAll('.array-item');
let deleteCount = 0;
proxyItems.forEach(item => {
const input = item.querySelector('.array-input');
// Check that the input exists and that its value is in the set to delete
if (input && proxiesToDelete.has(input.value)) {
item.remove(); // Remove the whole array item element
deleteCount++;
}
});
// Close the modal
bulkDeleteModal.classList.remove('show');
// Provide feedback
if (deleteCount > 0) {
showNotification(`Successfully deleted ${deleteCount} matching proxies`, 'success');
} else {
showNotification('None of the proxies you entered were found in the list', 'info');
}
// Clear the textarea after processing
bulkDeleteTextarea.value = '';
}
// --- End: bulk-delete proxy logic ---
// Switch tabs
function switchTab(tabId) {
// Update the tab button state
@@ -781,6 +996,23 @@ function collectFormData() {
}
// --- End: handle THINKING_BUDGET_MAP ---
// --- Added: handle SAFETY_SETTINGS ---
if (safetySettingsContainer) {
formData['SAFETY_SETTINGS'] = [];
const settingItems = safetySettingsContainer.querySelectorAll('.safety-setting-item');
settingItems.forEach(item => {
const categorySelect = item.querySelector('.safety-category-select');
const thresholdSelect = item.querySelector('.safety-threshold-select');
if (categorySelect && thresholdSelect && categorySelect.value && thresholdSelect.value) {
formData['SAFETY_SETTINGS'].push({
category: categorySelect.value,
threshold: thresholdSelect.value
});
}
});
}
// --- End: handle SAFETY_SETTINGS ---
return formData;
}
@@ -975,10 +1207,6 @@ function generateRandomToken() {
}
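// --- End: random token generator ---
// --- Changed: add a thinking model budget map item (now triggered by adding a thinking model) ---
// function addBudgetMapItem() {
// // Manual adding is no longer needed
// }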
// Deprecated: This function is now effectively replaced by createAndAppendBudgetMapItem
// for the initial population logic. It delegates to the new function if called.
@@ -988,3 +1216,86 @@ function addBudgetMapItemWithValue(mapKey, mapValue, modelId) {
createAndAppendBudgetMapItem(mapKey, mapValue, modelId);
}
/* --- End: (addBudgetMapItemWithValue is deprecated) --- */
// --- Added: function to add a safety setting item ---
function addSafetySettingItem(category = '', threshold = '') {
const container = document.getElementById('SAFETY_SETTINGS_container');
if (!container) {
console.error("Cannot add safety setting: SAFETY_SETTINGS_container not found!");
return;
}
// If the container currently holds only the placeholder, clear it
const placeholder = container.querySelector('.text-gray-500.italic');
if (placeholder && container.children.length === 1 && container.firstChild === placeholder) {
container.innerHTML = '';
}
const harmCategories = [
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY" // Add or remove as needed
];
const harmThresholds = [
"BLOCK_NONE",
"BLOCK_LOW_AND_ABOVE",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_ONLY_HIGH",
"OFF" // Add or remove according to the Google API documentation
];
const settingItem = document.createElement('div');
settingItem.className = 'safety-setting-item flex items-center mb-2 gap-2';
// Category Select
const categorySelect = document.createElement('select');
categorySelect.className = 'safety-category-select flex-grow px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 bg-white';
harmCategories.forEach(cat => {
const option = document.createElement('option');
option.value = cat;
option.textContent = cat.replace('HARM_CATEGORY_', ''); // Show a friendlier name
if (cat === category) {
option.selected = true;
}
categorySelect.appendChild(option);
});
// Threshold Select
const thresholdSelect = document.createElement('select');
thresholdSelect.className = 'safety-threshold-select w-48 px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 bg-white';
harmThresholds.forEach(thr => {
const option = document.createElement('option');
option.value = thr;
option.textContent = thr.replace('BLOCK_', '').replace('_AND_ABOVE', '+'); // Simplified display
if (thr === threshold) {
option.selected = true;
}
thresholdSelect.appendChild(option);
});
// Remove Button
const removeBtn = document.createElement('button');
removeBtn.type = 'button';
removeBtn.className = 'remove-btn text-gray-400 hover:text-red-500 focus:outline-none transition-colors duration-150';
removeBtn.innerHTML = '<i class="fas fa-trash-alt"></i>';
removeBtn.title = 'Delete this setting';
removeBtn.addEventListener('click', function() {
const currentItem = this.closest('.safety-setting-item');
currentItem.remove();
// Check whether the container is empty; if so, restore the placeholder
if (container.children.length === 0) {
container.innerHTML = '<div class="text-gray-500 text-sm italic">Define the safety filtering thresholds for the model.</div>';
}
});
settingItem.appendChild(categorySelect);
settingItem.appendChild(thresholdSelect);
settingItem.appendChild(removeBtn);
container.appendChild(settingItem);
}
// --- End: function to add a safety setting item ---
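
`handleBulkAddProxies` and `handleBulkDeleteProxies` in this file both extract proxy URLs from pasted text with the same regular expression; the add flow then merges them into the existing list and removes duplicates. The sketch below ports that extract-and-merge step to Python for illustration only. `merge_proxies` is a hypothetical name, and the pattern is a direct transliteration of the JavaScript one, so it inherits its limits (for instance, it does not validate host names).

```python
import re
from typing import Iterable, List

# Transliteration of the JavaScript proxyRegex from the diff.
PROXY_RE = re.compile(
    r"(?:https?|socks5)://(?:[^:@/]+(?::[^@/]+)?@)?(?:[^:/\s]+)(?::\d+)?"
)


def merge_proxies(current: Iterable[str], pasted_text: str) -> List[str]:
    """Merge proxies found in `pasted_text` into `current`, without duplicates."""
    existing = [proxy for proxy in current if proxy.strip()]
    extracted = PROXY_RE.findall(pasted_text)
    # dict.fromkeys keeps first-seen order, like a JavaScript Set.
    return list(dict.fromkeys([*existing, *extracted]))


merged = merge_proxies(
    ["http://a:1"],
    "http://a:1\nsocks5://1.2.3.4:1080\nnot a proxy",
)
```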


@@ -17,13 +17,27 @@ self.addEventListener('install', event => {
self.addEventListener('fetch', event => {
event.respondWith(
caches.match(event.request)
.then(response => {
if (response) {
return response;
}
return fetch(event.request);
})
caches.open(CACHE_NAME).then(cache => {
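// 1. Try the cache first
return cache.match(event.request).then(responseFromCache => {
// 2. Fetch from the network at the same time (in the background)
const fetchPromise = fetch(event.request).then(responseFromNetwork => {
// 3. The network request succeeded; update the cache
cache.put(event.request, responseFromNetwork.clone());
return responseFromNetwork;
}).catch(err => {
// On network failure, optionally log the error or do nothing
console.error('Network fetch failed:', err);
// Even if the network fails, a cached response (if any) is still returned
// If there is no cached response either, this Promise rejects
throw err;
});
// 4. If a cached response exists, return it immediately; otherwise wait for the network
// The background network request keeps running to refresh the cache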
return responseFromCache || fetchPromise;
});
})
);
});
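
The new fetch handler serves the cached copy immediately and refreshes the cache in the background, a pattern often called stale-while-revalidate. The sketch below is not the project's code. It restates the same flow as a function with injected dependencies so it can be exercised outside a service worker. The name `staleWhileRevalidate` and the `cache` / `fetchFn` parameters are hypothetical, and the method check and the `response.ok` check are additions this sketch assumes are desirable, since `Cache.put()` rejects non-GET requests.

```javascript
// Hypothetical helper: `cache` needs match()/put(), `fetchFn` behaves like fetch().
async function staleWhileRevalidate(request, cache, fetchFn) {
  // Cache.put() only accepts GET requests, so other methods bypass the cache.
  if (request.method !== 'GET') {
    return fetchFn(request);
  }

  const cached = await cache.match(request);

  // Start the network request regardless of whether the cache hit.
  const network = fetchFn(request).then(async response => {
    // Assumption of this sketch: only successful responses refresh the cache.
    if (response.ok) {
      await cache.put(request, response.clone());
    }
    return response;
  });

  if (cached) {
    // The caller never sees `network`, so handle its failure here to avoid
    // an unhandled promise rejection when the device is offline.
    network.catch(err => console.error('Background refresh failed:', err));
    return cached;
  }

  // No cached copy: the caller waits for the network and sees any failure.
  return network;
}
```

Inside a service worker this could be wired up as `event.respondWith(caches.open(CACHE_NAME).then(cache => staleWhileRevalidate(event.request, cache, fetch)))`.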


@@ -182,6 +182,22 @@
<input type="number" id="MAX_RETRIES" name="MAX_RETRIES" min="0" max="10" class="w-full px-4 py-3 rounded-lg border border-gray-300 focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50">
<small class="text-gray-500 mt-1 block">API请求失败后的最大重试次数</small>
</div>
<!-- Proxy server list -->
<div class="mb-6">
<label for="PROXIES" class="block font-semibold mb-2 text-gray-700">代理服务器列表</label>
<div class="array-container bg-white rounded-lg border border-gray-200 p-4 mb-2" id="PROXIES_container">
<!-- Proxy items are added here dynamically -->
</div>
<div class="flex justify-end gap-2">
<button type="button" class="bg-danger-600 hover:bg-danger-700 text-white px-4 py-2 rounded-lg font-medium transition-all duration-200 flex items-center gap-2" id="bulkDeleteProxyBtn">
<i class="fas fa-trash-alt"></i> 删除代理
</button>
<button type="button" class="bg-primary-600 hover:bg-primary-700 text-white px-4 py-2 rounded-lg font-medium transition-all duration-200 flex items-center gap-2" id="addProxyBtn">
<i class="fas fa-plus"></i> 添加代理
</button>
</div>
<small class="text-gray-500 mt-1 block">代理服务器列表,支持 http 和 socks5 格式,例如: http://user:pass@host:port 或 socks5://host:port。点击按钮可批量添加或删除。</small>
</div>
</div>
<!-- Model-related configuration -->
@@ -295,6 +311,20 @@
</div> -->
<small class="text-gray-500 mt-1 block">为每个思考模型设置预算(整数,最大值 24576),此项与上方模型列表自动关联。</small>
</div>
<!-- Safety settings -->
<div class="mb-6">
<label for="SAFETY_SETTINGS" class="block font-semibold mb-2 text-gray-700">安全设置 (Safety Settings)</label>
<div class="bg-white rounded-lg border border-gray-200 p-4 mb-2 space-y-3" id="SAFETY_SETTINGS_container">
<!-- Safety setting items are added here dynamically -->
<div class="text-gray-500 text-sm italic">定义模型的安全过滤阈值。</div>
</div>
<div class="flex justify-end">
<button type="button" class="bg-primary-600 hover:bg-primary-700 text-white px-4 py-2 rounded-lg font-medium transition-all duration-200 flex items-center gap-2" id="addSafetySettingBtn">
<i class="fas fa-plus"></i> 添加安全设置
</button>
</div>
<small class="text-gray-500 mt-1 block">配置模型的安全过滤级别,例如 HARM_CATEGORY_HARASSMENT: BLOCK_NONE。</small>
</div>
</div>
<!-- Image generation configuration -->
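
The help text above pairs a category with a threshold (`HARM_CATEGORY_HARASSMENT: BLOCK_NONE`). The sketch below shows one way the rows in `SAFETY_SETTINGS_container` could be turned into the `{category, threshold}` array that the Gemini API accepts as `safetySettings`. The function name `collectSafetySettings` and the rule that a later row overrides an earlier one for the same category are assumptions of this sketch. The diff does not show how the project serializes the form.

```javascript
// Hypothetical helper: `rows` is a list of { category, threshold } pairs, for example
// read from the two <select> elements of each `.safety-setting-item` row.
function collectSafetySettings(rows) {
  const byCategory = new Map();
  for (const { category, threshold } of rows) {
    // Skip incomplete rows; for duplicate categories the later row wins (assumed rule).
    if (category && threshold) {
      byCategory.set(category, threshold);
    }
  }
  // A Map keeps first-insertion order, so categories stay in their on-screen order.
  return Array.from(byCategory, ([category, threshold]) => ({ category, threshold }));
}

const settings = collectSafetySettings([
  { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE' },
  { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' },
  { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_LOW_AND_ABOVE' },
]);
console.log(JSON.stringify(settings));
```

The JSON string printed at the end is the kind of value that might be stored in a `SAFETY_SETTINGS` environment variable, although the exact format the backend expects is not shown in this section.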
@@ -511,6 +541,43 @@
</div>
</div>
</div>
<!-- Proxy Add Modal -->
<div id="proxyModal" class="modal">
<div class="w-full max-w-lg mx-auto bg-white rounded-2xl shadow-2xl overflow-hidden animate-fade-in">
<div class="p-6">
<div class="flex justify-between items-center mb-4">
<h2 class="text-xl font-bold text-gray-800">批量添加代理服务器</h2>
<button id="closeProxyModalBtn" class="text-gray-400 hover:text-gray-600 text-xl">&times;</button>
</div>
<p class="text-gray-600 mb-4">每行粘贴一个或多个代理地址,将自动提取有效地址并去重。</p>
<textarea id="proxyBulkInput" rows="10" placeholder="在此处粘贴代理地址 (例如 http://user:pass@host:port 或 socks5://host:port)..." class="w-full px-4 py-3 rounded-lg border border-gray-300 focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 font-mono text-sm"></textarea>
<div class="flex justify-end gap-3 mt-6">
<button type="button" id="confirmAddProxyBtn" class="bg-primary-600 hover:bg-primary-700 text-white px-6 py-2 rounded-lg font-medium transition">确认添加</button>
<button type="button" id="cancelAddProxyBtn" class="bg-gray-200 hover:bg-gray-300 text-gray-700 px-6 py-2 rounded-lg font-medium transition">取消</button>
</div>
</div>
</div>
</div>
<!-- Bulk Delete Proxy Modal -->
<div id="bulkDeleteProxyModal" class="modal">
<div class="w-full max-w-lg mx-auto bg-white rounded-2xl shadow-2xl overflow-hidden animate-fade-in">
<div class="p-6">
<div class="flex justify-between items-center mb-4">
<h2 class="text-xl font-bold text-gray-800">批量删除代理服务器</h2>
<button id="closeBulkDeleteProxyModalBtn" class="text-gray-400 hover:text-gray-600 text-xl">&times;</button>
</div>
<p class="text-gray-600 mb-4">每行粘贴一个或多个代理地址,将自动提取有效地址并从列表中删除。</p>
<textarea id="bulkDeleteProxyInput" rows="10" placeholder="在此处粘贴要删除的代理地址..." class="w-full px-4 py-3 rounded-lg border border-gray-300 focus:border-danger-500 focus:ring focus:ring-danger-200 focus:ring-opacity-50 font-mono text-sm"></textarea>
<div class="flex justify-end gap-3 mt-6">
<button type="button" id="confirmBulkDeleteProxyBtn" class="bg-danger-600 hover:bg-danger-700 text-white px-6 py-2 rounded-lg font-medium transition">确认删除</button>
<button type="button" id="cancelBulkDeleteProxyBtn" class="bg-gray-200 hover:bg-gray-300 text-gray-700 px-6 py-2 rounded-lg font-medium transition">取消</button>
</div>
</div>
</div>
</div>
<!-- Reset Confirmation Modal -->
<div id="resetConfirmModal" class="modal">
<div class="w-full max-w-md mx-auto bg-white rounded-2xl shadow-2xl overflow-hidden animate-fade-in">
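
Both proxy modals promise to extract valid addresses from pasted text, and the add modal also deduplicates them. A minimal sketch of that behavior follows. The regular expression and the names `PROXY_PATTERN`, `extractProxies`, `addProxies`, and `removeProxies` are hypothetical, since the project's own handler in `config_editor.js` is not part of this section and may validate addresses differently.

```javascript
// Hypothetical pattern: scheme, optional user:pass@, host, numeric port.
const PROXY_PATTERN = /(?:https?|socks5):\/\/(?:[^:@\/\s]+(?::[^@\/\s]*)?@)?[^:\/\s]+:\d+/g;

// Pull every proxy-looking address out of free-form text, dropping duplicates.
function extractProxies(text) {
  const matches = text.match(PROXY_PATTERN) || [];
  return [...new Set(matches)];
}

// Bulk add: append newly pasted proxies, keeping existing entries and their order.
function addProxies(current, text) {
  return [...new Set([...current, ...extractProxies(text)])];
}

// Bulk delete: remove every pasted proxy that is present in the current list.
function removeProxies(current, text) {
  const toRemove = new Set(extractProxies(text));
  return current.filter(proxy => !toRemove.has(proxy));
}

const pasted = 'http://user:pass@host:8080 socks5://1.2.3.4:1080\nnot-a-proxy\nhttp://user:pass@host:8080';
console.log(extractProxies(pasted));
```

Using a `Set` keeps the first occurrence of each address and preserves paste order, which matches the "one or more addresses per line" wording of the modal text.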


@@ -1,5 +1,5 @@
fastapi
httpx
httpx[socks]
openai
pydantic
pydantic_settings
@@ -16,6 +16,5 @@ sqlalchemy
aiomysql
databases
python-dotenv
apscheduler # scheduled-task library
apscheduler
packaging