mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 06:11:32 +08:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
150824938c | ||
|
|
ccaea40281 | ||
|
|
9d8e77c9f7 | ||
|
|
19941f7f50 | ||
|
|
d6981c204a | ||
|
|
d386cc7180 | ||
|
|
bed3647424 | ||
|
|
95b5acad66 | ||
|
|
68b65814bc | ||
|
|
88f5b33018 | ||
|
|
8c62c8121d | ||
|
|
05762cb6a5 | ||
|
|
78f38cc981 | ||
|
|
79f47c315e | ||
|
|
708fb1604b | ||
|
|
7dbd3ad693 | ||
|
|
67dd1af583 | ||
|
|
e104a50cf4 | ||
|
|
6b9647813b | ||
|
|
f863e3065b | ||
|
|
1314e0ee09 | ||
|
|
81d92370ad | ||
|
|
5f6eba62cc | ||
|
|
a8a265c2a7 | ||
|
|
ee21e50305 | ||
|
|
611559d298 | ||
|
|
b0127e6fc2 | ||
|
|
1d15a21ce5 | ||
|
|
c206aa8e4a | ||
|
|
3f040b7075 | ||
|
|
1771555fe9 | ||
|
|
8711088ebc | ||
|
|
bb6c629aef | ||
|
|
d06e418a61 | ||
|
|
fd39c2c9cb |
20
.env.example
20
.env.example
@@ -14,11 +14,11 @@ AUTH_TOKEN=sk-123456
|
||||
VERTEX_API_KEYS=["AQ.Abxxxxxxxxxxxxxxxxxxx"]
|
||||
# For Vertex AI Platform Express API Base URL
|
||||
VERTEX_EXPRESS_BASE_URL=https://aiplatform.googleapis.com/v1beta1/publishers/google
|
||||
TEST_MODEL=gemini-1.5-flash
|
||||
THINKING_MODELS=["gemini-2.5-flash-preview-04-17"]
|
||||
THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
|
||||
IMAGE_MODELS=["gemini-2.0-flash-exp"]
|
||||
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||
TEST_MODEL=gemini-2.5-flash-lite
|
||||
THINKING_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
|
||||
THINKING_BUDGET_MAP={"gemini-2.5-flash": -1}
|
||||
IMAGE_MODELS=["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
|
||||
SEARCH_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
|
||||
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
||||
# 是否启用网址上下文,默认启用
|
||||
URL_CONTEXT_ENABLED=false
|
||||
@@ -44,9 +44,17 @@ CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
||||
UPLOAD_PROVIDER=smms
|
||||
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||
PICGO_API_KEY=xxxx
|
||||
PICGO_API_URL=https://www.picgo.net/api/1/upload
|
||||
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER=
|
||||
# 阿里云OSS配置
|
||||
OSS_ENDPOINT=oss-cn-shanghai.aliyuncs.com
|
||||
OSS_ENDPOINT_INNER=oss-cn-shanghai-internal.aliyuncs.com
|
||||
OSS_ACCESS_KEY=LTAI5txxxxxxxxxxxxxxxx
|
||||
OSS_ACCESS_KEY_SECRET=yXxxxxxxxxxxxxxxxxxxxxx
|
||||
OSS_BUCKET_NAME=your-bucket-name
|
||||
OSS_REGION=cn-shanghai
|
||||
##########################################################################
|
||||
#########################stream_optimizer 相关配置########################
|
||||
STREAM_OPTIMIZER_ENABLED=false
|
||||
@@ -59,6 +67,8 @@ STREAM_CHUNK_SIZE=5
|
||||
######################### 日志配置 #######################################
|
||||
# 日志级别 (debug, info, warning, error, critical),默认为 info
|
||||
LOG_LEVEL=info
|
||||
# 是否记录错误日志的请求体(可能包含敏感信息),默认 false
|
||||
ERROR_LOG_RECORD_REQUEST_BODY=false
|
||||
# 是否开启自动删除错误日志
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED=true
|
||||
# 自动删除多少天前的错误日志 (1, 7, 30)
|
||||
|
||||
@@ -8,6 +8,9 @@ COPY ./VERSION /app
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY ./app /app/app
|
||||
ENV API_KEYS='["your_api_key_1"]'
|
||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||
ENV TZ='Asia/Shanghai'
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
22
README.md
22
README.md
@@ -137,6 +137,8 @@ app/
|
||||
|
||||
### Gemini API Format (`/gemini/v1beta`)
|
||||
|
||||
This endpoint is directly forwarded to official Gemini API format endpoint, without advanced features.
|
||||
|
||||
* `GET /models`: List available Gemini models.
|
||||
* `POST /models/{model_name}:generateContent`: Generate content.
|
||||
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation.
|
||||
@@ -145,6 +147,8 @@ app/
|
||||
|
||||
#### Hugging Face (HF) Compatible
|
||||
|
||||
If you want to use advanced features, like fake streaming, please use this endpoint.
|
||||
|
||||
* `GET /hf/v1/models`: List models.
|
||||
* `POST /hf/v1/chat/completions`: Chat completion.
|
||||
* `POST /hf/v1/embeddings`: Create text embeddings.
|
||||
@@ -152,6 +156,8 @@ app/
|
||||
|
||||
#### Standard OpenAI
|
||||
|
||||
This endpoint is directly forwarded to official OpenAI Compatible API format endpoint, without advanced features.
|
||||
|
||||
* `GET /openai/v1/models`: List models.
|
||||
* `POST /openai/v1/chat/completions`: Chat completion (Recommended).
|
||||
* `POST /openai/v1/embeddings`: Create text embeddings.
|
||||
@@ -178,9 +184,9 @@ app/
|
||||
| `ALLOWED_TOKENS` | **Required**, list of access tokens | `[]` |
|
||||
| `AUTH_TOKEN` | Super admin token, defaults to the first of `ALLOWED_TOKENS` | `sk-123456` |
|
||||
| `ADMIN_SESSION_EXPIRE` | Admin session expiration time in seconds (5 minutes to 24 hours) | `3600` |
|
||||
| `TEST_MODEL` | Model for testing key validity | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | Models supporting image generation | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | Models supporting web search | `["gemini-2.0-flash-exp"]` |
|
||||
| `TEST_MODEL` | Model for testing key validity | `gemini-2.5-flash-lite` |
|
||||
| `IMAGE_MODELS` | Models supporting image generation | `["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]` |
|
||||
| `SEARCH_MODELS` | Models supporting web search | `["gemini-2.5-flash","gemini-2.5-pro"]` |
|
||||
| `FILTERED_MODELS` | Disabled models | `[]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | Enable code execution tool | `false` |
|
||||
| `SHOW_SEARCH_LINK` | Display search result links in response | `true` |
|
||||
@@ -199,6 +205,7 @@ app/
|
||||
| `PROXIES` | List of proxy servers | `[]` |
|
||||
| **Logging & Security** | | |
|
||||
| `LOG_LEVEL` | Log level: `DEBUG`, `INFO`, `WARNING`, `ERROR` | `INFO` |
|
||||
| `ERROR_LOG_RECORD_REQUEST_BODY` | Record request body in error logs (may contain sensitive information) | `false` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | Auto-delete error logs | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | Error log retention period (days) | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Auto-delete request logs | `false` |
|
||||
@@ -211,9 +218,16 @@ app/
|
||||
| **Image Generation** | | |
|
||||
| `PAID_KEY` | Paid API Key for advanced features | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | Image generation model | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | Image upload provider: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `UPLOAD_PROVIDER` | Image upload provider: `smms`, `picgo`, `cloudflare_imgbed`, `aliyun_oss` | `smms` |
|
||||
| `OSS_ENDPOINT` | Aliyun OSS public endpoint | `oss-cn-shanghai.aliyuncs.com` |
|
||||
| `OSS_ENDPOINT_INNER` | Aliyun OSS internal endpoint (intra-VPC) | `oss-cn-shanghai-internal.aliyuncs.com` |
|
||||
| `OSS_ACCESS_KEY` | Aliyun AccessKey ID | `LTAI5txxxxxxxxxxxxxxxx` |
|
||||
| `OSS_ACCESS_KEY_SECRET` | Aliyun AccessKey Secret | `yXxxxxxxxxxxxxxxxxxxxxx` |
|
||||
| `OSS_BUCKET_NAME` | Aliyun OSS bucket name | `your-bucket-name` |
|
||||
| `OSS_REGION` | Aliyun OSS region | `cn-shanghai` |
|
||||
| `SMMS_SECRET_TOKEN` | SM.MS API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | PicoGo API Key | `your-picogo-apikey` |
|
||||
| `PICGO_API_URL` | PicoGo API Server URL | `https://www.picgo.net/api/1/upload` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | CloudFlare ImgBed upload URL | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| CloudFlare ImgBed auth key | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| CloudFlare ImgBed upload folder | `""` |
|
||||
|
||||
22
README_ZH.md
22
README_ZH.md
@@ -138,6 +138,8 @@ app/
|
||||
|
||||
### Gemini API 格式 (`/gemini/v1beta`)
|
||||
|
||||
此端点将请求直接转发到官方 Gemini API 格式的端点,不包含高级功能。
|
||||
|
||||
* `GET /models`: 列出可用的 Gemini 模型。
|
||||
* `POST /models/{model_name}:generateContent`: 生成内容。
|
||||
* `POST /models/{model_name}:streamGenerateContent`: 流式生成内容。
|
||||
@@ -146,6 +148,8 @@ app/
|
||||
|
||||
#### 兼容 huggingface (HF) 格式
|
||||
|
||||
如果您需要使用高级功能(例如假流式输出),请使用此端点。
|
||||
|
||||
* `GET /hf/v1/models`: 列出模型。
|
||||
* `POST /hf/v1/chat/completions`: 聊天补全。
|
||||
* `POST /hf/v1/embeddings`: 创建文本嵌入。
|
||||
@@ -153,6 +157,8 @@ app/
|
||||
|
||||
#### 标准 OpenAI 格式
|
||||
|
||||
此端点直接转发至官方的 OpenAI 兼容 API 格式端点,不包含高级功能。
|
||||
|
||||
* `GET /openai/v1/models`: 列出模型。
|
||||
* `POST /openai/v1/chat/completions`: 聊天补全 (推荐,速度更快,防截断)。
|
||||
* `POST /openai/v1/embeddings`: 创建文本嵌入。
|
||||
@@ -178,9 +184,9 @@ app/
|
||||
| `API_KEYS` | **必填**, Gemini API 密钥列表,用于负载均衡 | `[]` |
|
||||
| `ALLOWED_TOKENS` | **必填**, 允许访问的 Token 列表 | `[]` |
|
||||
| `AUTH_TOKEN` | 超级管理员 Token,不填则使用 `ALLOWED_TOKENS` 的第一个 | `sk-123456` |
|
||||
| `TEST_MODEL` | 用于测试密钥可用性的模型 | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | 支持绘图功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | 支持搜索功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `TEST_MODEL` | 用于测试密钥可用性的模型 | `gemini-2.5-flash-lite` |
|
||||
| `IMAGE_MODELS` | 支持绘图功能的模型列表 | `["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]` |
|
||||
| `SEARCH_MODELS` | 支持搜索功能的模型列表 | `["gemini-2.5-flash","gemini-2.5-pro"]` |
|
||||
| `FILTERED_MODELS` | 被禁用的模型列表 | `[]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | 是否启用代码执行工具 | `false` |
|
||||
| `SHOW_SEARCH_LINK` | 是否在响应中显示搜索结果链接 | `true` |
|
||||
@@ -199,6 +205,7 @@ app/
|
||||
| `PROXIES` | 代理服务器列表 (例如 `http://user:pass@host:port`) | `[]` |
|
||||
| **日志与安全** | | |
|
||||
| `LOG_LEVEL` | 日志级别: `DEBUG`, `INFO`, `WARNING`, `ERROR` | `INFO` |
|
||||
| `ERROR_LOG_RECORD_REQUEST_BODY` | 是否记录错误日志的请求体(可能包含敏感信息) | `false` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | 是否自动删除错误日志 | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | 错误日志保留天数 | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 是否自动删除请求日志 | `false` |
|
||||
@@ -211,9 +218,16 @@ app/
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | 图片生成模型 | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | 图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `UPLOAD_PROVIDER` | 图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed`, `aliyun_oss` | `smms` |
|
||||
| `OSS_ENDPOINT` | 阿里云 OSS 公网 Endpoint | `oss-cn-shanghai.aliyuncs.com` |
|
||||
| `OSS_ENDPOINT_INNER` | 阿里云 OSS 内网 Endpoint(同 VPC 内网访问) | `oss-cn-shanghai-internal.aliyuncs.com` |
|
||||
| `OSS_ACCESS_KEY` | 阿里云 AccessKey ID | `LTAI5txxxxxxxxxxxxxxxx` |
|
||||
| `OSS_ACCESS_KEY_SECRET` | 阿里云 AccessKey Secret | `yXxxxxxxxxxxxxxxxxxxxxx` |
|
||||
| `OSS_BUCKET_NAME` | 阿里云 OSS Bucket 名称 | `your-bucket-name` |
|
||||
| `OSS_REGION` | 阿里云 OSS 区域 Region | `cn-shanghai` |
|
||||
| `SMMS_SECRET_TOKEN` | SM.MS图床的API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | [PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
||||
| `PICGO_API_URL` | [PicoGo](https://www.picgo.net/)图床的API服务器地址 | `https://www.picgo.net/api/1/upload` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| CloudFlare图床的上传文件夹路径 | `""` |
|
||||
|
||||
@@ -6,7 +6,7 @@ import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List, Type, get_args, get_origin
|
||||
|
||||
from pydantic import ValidationError, ValidationInfo, field_validator, Field
|
||||
from pydantic import Field, ValidationError, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from sqlalchemy import insert, select, update
|
||||
|
||||
@@ -51,8 +51,8 @@ class Settings(BaseSettings):
|
||||
return v
|
||||
|
||||
# API相关配置
|
||||
API_KEYS: List[str]=[]
|
||||
ALLOWED_TOKENS: List[str]=[]
|
||||
API_KEYS: List[str] = []
|
||||
ALLOWED_TOKENS: List[str] = []
|
||||
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
||||
AUTH_TOKEN: str = ""
|
||||
MAX_FAILURES: int = 3
|
||||
@@ -62,7 +62,9 @@ class Settings(BaseSettings):
|
||||
PROXIES: List[str] = []
|
||||
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
|
||||
VERTEX_API_KEYS: List[str] = []
|
||||
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
||||
VERTEX_EXPRESS_BASE_URL: str = (
|
||||
"https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
||||
)
|
||||
|
||||
# 智能路由配置
|
||||
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
|
||||
@@ -71,13 +73,19 @@ class Settings(BaseSettings):
|
||||
CUSTOM_HEADERS: Dict[str, str] = {}
|
||||
|
||||
# 模型相关配置
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
|
||||
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
# 是否启用网址上下文
|
||||
URL_CONTEXT_ENABLED: bool = False
|
||||
URL_CONTEXT_MODELS: List[str] = ["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
|
||||
URL_CONTEXT_MODELS: List[str] = [
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-live-001",
|
||||
]
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
THINKING_MODELS: List[str] = []
|
||||
@@ -94,9 +102,17 @@ class Settings(BaseSettings):
|
||||
UPLOAD_PROVIDER: str = "smms"
|
||||
SMMS_SECRET_TOKEN: str = ""
|
||||
PICGO_API_KEY: str = ""
|
||||
PICGO_API_URL: str = "https://www.picgo.net/api/1/upload"
|
||||
CLOUDFLARE_IMGBED_URL: str = ""
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER: str = ""
|
||||
# 阿里云OSS配置
|
||||
OSS_ENDPOINT: str = ""
|
||||
OSS_ENDPOINT_INNER: str = ""
|
||||
OSS_ACCESS_KEY: str = ""
|
||||
OSS_ACCESS_KEY_SECRET: str = ""
|
||||
OSS_BUCKET_NAME: str = ""
|
||||
OSS_REGION: str = ""
|
||||
|
||||
# 流式输出优化器配置
|
||||
STREAM_OPTIMIZER_ENABLED: bool = False
|
||||
@@ -120,6 +136,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = "INFO"
|
||||
ERROR_LOG_RECORD_REQUEST_BODY: bool = False
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True
|
||||
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7
|
||||
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False
|
||||
@@ -136,7 +153,7 @@ class Settings(BaseSettings):
|
||||
default=3600,
|
||||
ge=300,
|
||||
le=86400,
|
||||
description="Admin session expiration time in seconds (5 minutes to 24 hours)"
|
||||
description="Admin session expiration time in seconds (5 minutes to 24 hours)",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -168,7 +185,9 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
return [
|
||||
item.strip() for item in db_value.split(",") if item.strip()
|
||||
]
|
||||
logger.warning(
|
||||
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
||||
)
|
||||
@@ -220,7 +239,9 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict.")
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict."
|
||||
)
|
||||
return parsed_dict
|
||||
# 处理 Dict[str, float]
|
||||
elif args and args == (str, float):
|
||||
@@ -242,7 +263,9 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
parsed_dict = {
|
||||
str(k): float(v) for k, v in parsed.items()
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
||||
@@ -403,9 +426,7 @@ async def sync_initial_settings():
|
||||
|
||||
# 序列化值为字符串或 JSON 字符串
|
||||
if isinstance(value, (list, dict)):
|
||||
db_value = json.dumps(
|
||||
value, ensure_ascii=False
|
||||
)
|
||||
db_value = json.dumps(value, ensure_ascii=False)
|
||||
elif isinstance(value, bool):
|
||||
db_value = str(value).lower()
|
||||
elif value is None:
|
||||
|
||||
@@ -9,7 +9,7 @@ MAX_RETRIES = 3 # 最大重试次数
|
||||
|
||||
# 模型相关常量
|
||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||
DEFAULT_MODEL = "gemini-1.5-flash"
|
||||
DEFAULT_MODEL = "gemini-2.5-flash-lite"
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
DEFAULT_MAX_TOKENS = 8192
|
||||
DEFAULT_TOP_P = 0.9
|
||||
@@ -27,7 +27,7 @@ DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
||||
|
||||
# 上传提供商
|
||||
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
|
||||
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed", "aliyun_oss"]
|
||||
DEFAULT_UPLOAD_PROVIDER = "smms"
|
||||
|
||||
# 流式输出相关常量
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
数据库服务模块
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
@@ -122,16 +123,19 @@ async def add_error_log(
|
||||
bool: 是否添加成功
|
||||
"""
|
||||
try:
|
||||
# 如果request_msg是字典,则转换为JSON字符串
|
||||
if isinstance(request_msg, dict):
|
||||
request_msg_json = request_msg
|
||||
elif isinstance(request_msg, str):
|
||||
try:
|
||||
request_msg_json = json.loads(request_msg)
|
||||
except json.JSONDecodeError:
|
||||
request_msg_json = {"message": request_msg}
|
||||
else:
|
||||
if request_msg is None:
|
||||
request_msg_json = None
|
||||
else:
|
||||
# 如果request_msg是字典,则转换为JSON字符串
|
||||
if isinstance(request_msg, dict):
|
||||
request_msg_json = request_msg
|
||||
elif isinstance(request_msg, str):
|
||||
try:
|
||||
request_msg_json = json.loads(request_msg)
|
||||
except json.JSONDecodeError:
|
||||
request_msg_json = {"message": request_msg}
|
||||
else:
|
||||
request_msg_json = None
|
||||
|
||||
# 插入错误日志
|
||||
query = insert(ErrorLog).values(
|
||||
@@ -446,24 +450,50 @@ async def delete_error_log_by_id(log_id: int) -> bool:
|
||||
|
||||
async def delete_all_error_logs() -> int:
|
||||
"""
|
||||
删除所有错误日志条目。
|
||||
分批删除所有错误日志,以避免大数据量下的超时和性能问题。
|
||||
|
||||
Returns:
|
||||
int: 被删除的错误日志数量。如果使用的数据库驱动不支持返回受影响行数,则返回 -1 表示操作成功。
|
||||
int: 被删除的错误日志总数。
|
||||
"""
|
||||
total_deleted_count = 0
|
||||
# SQLite 对 SQL 参数数量有上限(常见为 999),IN 子句中过多参数会报错
|
||||
# 统一使用 500,兼容 SQLite/MySQL,必要时可在配置中暴露该值
|
||||
batch_size = 200
|
||||
|
||||
try:
|
||||
# 直接执行删除操作,避免不必要的查询
|
||||
delete_query = delete(ErrorLog)
|
||||
await database.execute(delete_query)
|
||||
while True:
|
||||
# 1) 读取一批待删除的ID,仅选择ID列以提升效率
|
||||
id_query = select(ErrorLog.id).order_by(ErrorLog.id).limit(batch_size)
|
||||
rows = await database.fetch_all(id_query)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
logger.info("Successfully deleted all error logs.")
|
||||
ids = [row["id"] for row in rows]
|
||||
|
||||
# 由于 databases 库的 execute 方法不返回受影响的行数,
|
||||
# 返回 -1 表示删除操作成功执行,但具体删除数量未知
|
||||
# 这比先查询再删除的方式更高效
|
||||
return -1
|
||||
# 2) 按ID批量删除
|
||||
delete_query = delete(ErrorLog).where(ErrorLog.id.in_(ids))
|
||||
await database.execute(delete_query)
|
||||
|
||||
deleted_in_batch = len(ids)
|
||||
total_deleted_count += deleted_in_batch
|
||||
|
||||
logger.debug(f"Deleted a batch of {deleted_in_batch} error logs.")
|
||||
|
||||
# 若不足一个批次,说明已删除完成
|
||||
if deleted_in_batch < batch_size:
|
||||
break
|
||||
|
||||
# 3) 将控制权交还事件循环,缓解长时间占用
|
||||
await asyncio.sleep(0)
|
||||
|
||||
logger.info(
|
||||
f"Successfully deleted all error logs in batches. Total deleted: {total_deleted_count}"
|
||||
)
|
||||
return total_deleted_count
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete all error logs: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to delete all error logs in batches: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -131,10 +131,5 @@ def setup_exception_handlers(app: FastAPI) -> None:
|
||||
logger.exception(f"Unhandled Exception: {str(exc)}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"code": "internal_server_error",
|
||||
"message": "An unexpected error occurred",
|
||||
}
|
||||
},
|
||||
content=str(exc),
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ class MessageConverter(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
self, messages: List[Dict[str, Any]], model: str
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
@@ -84,7 +84,7 @@ def _convert_image_to_base64(url: str) -> str:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
|
||||
|
||||
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
def _process_text_with_image(text: str, model: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理可能包含图片URL的文本,提取图片并转换为base64
|
||||
|
||||
@@ -94,17 +94,31 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 包含文本和图片的部分列表
|
||||
"""
|
||||
# 如果模型名中没有包含image,当作普通文本处理
|
||||
if "image" not in model:
|
||||
return [{"text": text}]
|
||||
parts = []
|
||||
img_url_match = re.search(IMAGE_URL_PATTERN, text)
|
||||
if img_url_match:
|
||||
# 提取URL
|
||||
img_url = img_url_match.group(2)
|
||||
# 将URL对应的图片转换为base64
|
||||
# 先判断是否是base64url如果是,直接用,不过不是,再将URL对应的图片转换为base64
|
||||
try:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append(
|
||||
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
||||
)
|
||||
base64_url_match = re.search(DATA_URL_PATTERN, img_url)
|
||||
if base64_url_match:
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": base64_url_match.group(1),
|
||||
"data": base64_url_match.group(2),
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append(
|
||||
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
||||
)
|
||||
except Exception:
|
||||
# 如果转换失败,回退到文本模式
|
||||
parts.append({"text": text})
|
||||
@@ -145,7 +159,7 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
raise
|
||||
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
self, messages: List[Dict[str, Any]], model: str
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
system_instruction_parts = []
|
||||
@@ -296,7 +310,7 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
elif (
|
||||
"content" in msg and isinstance(msg["content"], str) and msg["content"]
|
||||
):
|
||||
parts.extend(_process_text_with_image(msg["content"]))
|
||||
parts.extend(_process_text_with_image(msg["content"], model))
|
||||
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||
# Keep existing tool call processing
|
||||
for tool_call in msg["tool_calls"]:
|
||||
|
||||
@@ -8,8 +8,9 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.utils.helpers import is_image_upload_configured
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
@@ -32,7 +33,11 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
usage_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if stream:
|
||||
return _handle_gemini_stream_response(response, model, stream)
|
||||
@@ -40,7 +45,10 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
|
||||
|
||||
def _handle_openai_stream_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
finish_reason: str,
|
||||
usage_metadata: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
choices = []
|
||||
candidates = response.get("candidates", [])
|
||||
@@ -54,15 +62,15 @@ def _handle_openai_stream_response(
|
||||
if not text and not tool_calls and not reasoning_content:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"}
|
||||
delta = {
|
||||
"content": text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"role": "assistant",
|
||||
}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
|
||||
choice = {
|
||||
"index": index,
|
||||
"delta": delta,
|
||||
"finish_reason": finish_reason
|
||||
}
|
||||
|
||||
choice = {"index": index, "delta": delta, "finish_reason": finish_reason}
|
||||
choices.append(choice)
|
||||
|
||||
template_chunk = {
|
||||
@@ -73,16 +81,23 @@ def _handle_openai_stream_response(
|
||||
"choices": choices,
|
||||
}
|
||||
if usage_metadata:
|
||||
template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}
|
||||
template_chunk["usage"] = {
|
||||
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
||||
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
||||
}
|
||||
return template_chunk
|
||||
|
||||
|
||||
def _handle_openai_normal_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
finish_reason: str,
|
||||
usage_metadata: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
choices = []
|
||||
candidates = response.get("candidates", [])
|
||||
|
||||
|
||||
for i, candidate in enumerate(candidates):
|
||||
text, reasoning_content, tool_calls, _ = _extract_result(
|
||||
{"candidates": [candidate]}, model, stream=False, gemini_format=False
|
||||
@@ -105,7 +120,11 @@ def _handle_openai_normal_response(
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": choices,
|
||||
"usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)},
|
||||
"usage": {
|
||||
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
||||
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -126,8 +145,12 @@ class OpenAIResponseHandler(ResponseHandler):
|
||||
usage_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if stream:
|
||||
return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
|
||||
return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)
|
||||
return _handle_openai_stream_response(
|
||||
response, model, finish_reason, usage_metadata
|
||||
)
|
||||
return _handle_openai_normal_response(
|
||||
response, model, finish_reason, usage_metadata
|
||||
)
|
||||
|
||||
def handle_image_chat_response(
|
||||
self, image_str: str, model: str, stream=False, finish_reason="stop"
|
||||
@@ -181,7 +204,7 @@ def _extract_result(
|
||||
gemini_format: bool = False,
|
||||
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
|
||||
text, reasoning_content, tool_calls, thought = "", "", [], None
|
||||
|
||||
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
@@ -190,7 +213,7 @@ def _extract_result(
|
||||
if not parts:
|
||||
logger.warning("No parts found in stream response")
|
||||
return "", None, [], None
|
||||
|
||||
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
if "thought" in parts[0]:
|
||||
@@ -216,13 +239,13 @@ def _extract_result(
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
text, reasoning_content = "", ""
|
||||
|
||||
|
||||
# 使用安全的访问方式
|
||||
content = candidate.get("content", {})
|
||||
|
||||
|
||||
if content and isinstance(content, dict):
|
||||
parts = content.get("parts", [])
|
||||
|
||||
|
||||
if parts:
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
@@ -240,17 +263,28 @@ def _extract_result(
|
||||
logger.error(f"Invalid content structure for model: {model}")
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
|
||||
|
||||
# 安全地获取 parts 用于工具调用提取
|
||||
parts = candidate.get("content", {}).get("parts", [])
|
||||
tool_calls = _extract_tool_calls(parts, gemini_format)
|
||||
else:
|
||||
logger.warning(f"No candidates found in response for model: {model}")
|
||||
text = "暂无返回"
|
||||
|
||||
|
||||
return text, reasoning_content, tool_calls, thought
|
||||
|
||||
|
||||
def _has_inline_image_part(response: Dict[str, Any]) -> bool:
|
||||
try:
|
||||
for c in response.get("candidates", []):
|
||||
for p in c.get("content", {}).get("parts", []):
|
||||
if isinstance(p, dict) and ("inlineData" in p):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def _extract_image_data(part: dict) -> str:
|
||||
image_uploader = None
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
@@ -259,7 +293,9 @@ def _extract_image_data(part: dict) -> str:
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.PICGO_API_KEY,
|
||||
api_url=settings.PICGO_API_URL
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
@@ -268,16 +304,30 @@ def _extract_image_data(part: dict) -> str:
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "aliyun_oss":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
access_key=settings.OSS_ACCESS_KEY,
|
||||
access_key_secret=settings.OSS_ACCESS_KEY_SECRET,
|
||||
bucket_name=settings.OSS_BUCKET_NAME,
|
||||
endpoint=settings.OSS_ENDPOINT,
|
||||
region=settings.OSS_REGION,
|
||||
use_internal=False
|
||||
)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
base64_data = part["inlineData"]["data"]
|
||||
mime_type = part["inlineData"]["mimeType"]
|
||||
# 将base64_data转成bytes数组
|
||||
# Return empty string if no uploader is configured
|
||||
if not is_image_upload_configured(settings):
|
||||
return f"\n\n\n\n"
|
||||
bytes_data = base64.b64decode(base64_data)
|
||||
upload_response = image_uploader.upload(bytes_data, filename)
|
||||
if upload_response.success:
|
||||
text = f"\n\n\n\n"
|
||||
else:
|
||||
text = ""
|
||||
text = f"\n\n\n\n"
|
||||
return text
|
||||
|
||||
|
||||
@@ -290,7 +340,7 @@ def _extract_tool_calls(
|
||||
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
tool_calls = list()
|
||||
|
||||
|
||||
for i in range(len(parts)):
|
||||
part = parts[i]
|
||||
if not part or not isinstance(part, dict):
|
||||
@@ -299,7 +349,7 @@ def _extract_tool_calls(
|
||||
item = part.get("functionCall", {})
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
|
||||
if gemini_format:
|
||||
tool_calls.append(part)
|
||||
else:
|
||||
@@ -322,6 +372,10 @@ def _extract_tool_calls(
|
||||
def _handle_gemini_stream_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
# Early return raw Gemini response if no uploader configured and contains inline images
|
||||
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
|
||||
return response
|
||||
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
@@ -339,6 +393,10 @@ def _handle_gemini_stream_response(
|
||||
def _handle_gemini_normal_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
# Early return raw Gemini response if no uploader configured and contains inline images
|
||||
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
|
||||
return response
|
||||
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
@@ -347,7 +405,7 @@ def _handle_gemini_normal_response(
|
||||
parts = tool_calls
|
||||
else:
|
||||
if thought is not None:
|
||||
parts.append({"text": reasoning_content,"thought": thought})
|
||||
parts.append({"text": reasoning_content, "thought": thought})
|
||||
part = {"text": text}
|
||||
parts.append(part)
|
||||
content = {"parts": parts, "role": "model"}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import platform
|
||||
import sys
|
||||
import re
|
||||
import sys
|
||||
from typing import Dict, Optional
|
||||
from app.utils.helpers import redact_key_for_logging as _redact_key_for_logging
|
||||
|
||||
# ANSI转义序列颜色代码
|
||||
COLORS = {
|
||||
@@ -15,7 +14,6 @@ COLORS = {
|
||||
}
|
||||
|
||||
|
||||
|
||||
# Windows系统启用ANSI支持
|
||||
if platform.system() == "Windows":
|
||||
import ctypes
|
||||
@@ -46,14 +44,16 @@ class AccessLogFormatter(logging.Formatter):
|
||||
|
||||
# API key patterns to match in URLs
|
||||
API_KEY_PATTERNS = [
|
||||
r'\bAIza[0-9A-Za-z_-]{35}', # Google API keys (like Gemini)
|
||||
r'\bsk-[0-9A-Za-z_-]{20,}', # OpenAI and general sk- prefixed keys
|
||||
r"\bAIza[0-9A-Za-z_-]{35}", # Google API keys (like Gemini)
|
||||
r"\bsk-[0-9A-Za-z_-]{20,}", # OpenAI and general sk- prefixed keys
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Compile regex patterns for better performance
|
||||
self.compiled_patterns = [re.compile(pattern) for pattern in self.API_KEY_PATTERNS]
|
||||
self.compiled_patterns = [
|
||||
re.compile(pattern) for pattern in self.API_KEY_PATTERNS
|
||||
]
|
||||
|
||||
def format(self, record):
|
||||
# Format the record normally first
|
||||
@@ -68,9 +68,10 @@ class AccessLogFormatter(logging.Formatter):
|
||||
"""
|
||||
try:
|
||||
for pattern in self.compiled_patterns:
|
||||
|
||||
def replace_key(match):
|
||||
key = match.group(0)
|
||||
return _redact_key_for_logging(key)
|
||||
return redact_key_for_logging(key)
|
||||
|
||||
message = pattern.sub(replace_key, message)
|
||||
|
||||
@@ -78,11 +79,31 @@ class AccessLogFormatter(logging.Formatter):
|
||||
except Exception as e:
|
||||
# Log the error but don't expose the original message in case it contains keys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error redacting API keys in access log: {e}")
|
||||
return "[LOG_REDACTION_ERROR]"
|
||||
|
||||
|
||||
def redact_key_for_logging(key: str) -> str:
|
||||
"""
|
||||
Redacts API key for secure logging by showing only first and last 6 characters.
|
||||
|
||||
Args:
|
||||
key: API key to redact
|
||||
|
||||
Returns:
|
||||
str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases
|
||||
"""
|
||||
if not key:
|
||||
return key
|
||||
|
||||
if len(key) <= 12:
|
||||
return f"{key[:3]}...{key[-3:]}"
|
||||
else:
|
||||
return f"{key[:6]}...{key[-6:]}"
|
||||
|
||||
|
||||
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
|
||||
FORMATTER = ColoredFormatter(
|
||||
"%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
|
||||
@@ -326,4 +347,3 @@ def setup_access_logging():
|
||||
access_logger.propagate = False
|
||||
|
||||
return access_logger
|
||||
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest, GeminiEmbedRequest, GeminiBatchEmbedRequest
|
||||
from app.domain.gemini_models import (
|
||||
GeminiBatchEmbedRequest,
|
||||
GeminiContent,
|
||||
GeminiEmbedRequest,
|
||||
GeminiRequest,
|
||||
ResetSelectedKeysRequest,
|
||||
VerifySelectedKeysRequest,
|
||||
)
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.embedding.gemini_embedding_service import GeminiEmbeddingService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_routes import get_tts_chat_service
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
from app.service.tts.native.tts_routes import get_tts_chat_service
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
||||
@@ -47,8 +56,8 @@ async def get_embedding_service(key_manager: KeyManager = Depends(get_key_manage
|
||||
@router.get("/models")
|
||||
@router_v1beta.get("/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
@@ -58,20 +67,30 @@ async def list_models(
|
||||
try:
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
raise HTTPException(
|
||||
status_code=503, detail="No valid API keys available to fetch models."
|
||||
)
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to fetch base models list."
|
||||
)
|
||||
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
model_mapping = {
|
||||
x.get("name", "").split("/", maxsplit=1)[-1]: x
|
||||
for x in models_json.get("models", [])
|
||||
}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
logger.warning(
|
||||
f"Base model '{base_name}' not found for derived model '{suffix}'."
|
||||
)
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
@@ -85,7 +104,7 @@ async def list_models(
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
@@ -97,7 +116,8 @@ async def list_models(
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
status_code=500,
|
||||
detail="Internal server error while fetching Gemini models list",
|
||||
) from e
|
||||
|
||||
|
||||
@@ -107,15 +127,19 @@ async def list_models(
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Content generation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini content generation request for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
# 检测是否为原生Gemini TTS请求
|
||||
@@ -132,10 +156,13 @@ async def generate_content(
|
||||
logger.info(f"TTS responseModalities: {response_modalities}")
|
||||
logger.info(f"TTS speechConfig: {speech_config}")
|
||||
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
# 所有原生TTS请求都使用TTS增强服务
|
||||
if is_native_tts:
|
||||
@@ -143,19 +170,17 @@ async def generate_content(
|
||||
logger.info("Using native TTS enhanced service")
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
response = await tts_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"Native TTS processing failed, falling back to standard service: {e}")
|
||||
logger.warning(
|
||||
f"Native TTS processing failed, falling back to standard service: {e}"
|
||||
)
|
||||
|
||||
# 使用标准服务处理所有其他请求(非TTS)
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -166,27 +191,53 @@ async def generate_content(
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Streaming request initiation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini streaming content generation for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
raw_stream = chat_service.stream_generate_content(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_stream, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_stream:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:countTokens")
|
||||
@@ -195,53 +246,60 @@ async def stream_generate_content(
|
||||
async def count_tokens(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini token 计数请求。"""
|
||||
operation_name = "gemini_count_tokens"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Token counting failed"):
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Token counting failed"
|
||||
):
|
||||
logger.info(f"Handling Gemini token count request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await chat_service.count_tokens(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:embedContent")
|
||||
@router_v1beta.post("/models/{model_name}:embedContent")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def embed_content(
|
||||
model_name: str,
|
||||
request: GeminiEmbedRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
|
||||
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
|
||||
):
|
||||
"""处理 Gemini 单一嵌入请求"""
|
||||
operation_name = "gemini_embed_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Embedding content generation failed"):
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Embedding content generation failed"
|
||||
):
|
||||
logger.info(f"Handling Gemini embedding request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await embedding_service.embed_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -252,41 +310,48 @@ async def embed_content(
|
||||
async def batch_embed_contents(
|
||||
model_name: str,
|
||||
request: GeminiBatchEmbedRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
|
||||
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
|
||||
):
|
||||
"""处理 Gemini 批量嵌入请求"""
|
||||
operation_name = "gemini_batch_embed_contents"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Batch embedding content generation failed"):
|
||||
async with handle_route_errors(
|
||||
logger,
|
||||
operation_name,
|
||||
failure_message="Batch embedding content generation failed",
|
||||
):
|
||||
logger.info(f"Handling Gemini batch embedding request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await embedding_service.batch_embed_contents(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
async def reset_all_key_fail_counts(
|
||||
key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
|
||||
logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
|
||||
logger.info(f"Received reset request with key_type: {key_type}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取分类后的密钥
|
||||
keys_by_status = await key_manager.get_keys_by_status()
|
||||
valid_keys = keys_by_status.get("valid_keys", {})
|
||||
invalid_keys = keys_by_status.get("invalid_keys", {})
|
||||
|
||||
|
||||
# 根据类型选择要重置的密钥
|
||||
keys_to_reset = []
|
||||
if key_type == "valid":
|
||||
@@ -298,35 +363,45 @@ async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManage
|
||||
else:
|
||||
# 重置所有密钥
|
||||
await key_manager.reset_failure_counts()
|
||||
return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})
|
||||
|
||||
return JSONResponse(
|
||||
{"success": True, "message": "所有密钥的失败计数已重置"}
|
||||
)
|
||||
|
||||
# 批量重置指定类型的密钥
|
||||
for key in keys_to_reset:
|
||||
await key_manager.reset_key_failure_count(key)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"{key_type}密钥的失败计数已重置",
|
||||
"reset_count": len(keys_to_reset)
|
||||
})
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"{key_type}密钥的失败计数已重置",
|
||||
"reset_count": len(keys_to_reset),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset key failure counts: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
|
||||
|
||||
|
||||
return JSONResponse(
|
||||
{"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reset-selected-fail-counts")
|
||||
async def reset_selected_key_fail_counts(
|
||||
request: ResetSelectedKeysRequest,
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""批量重置选定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
|
||||
keys_to_reset = request.keys
|
||||
key_type = request.key_type
|
||||
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
|
||||
logger.info(
|
||||
f"Received reset request for {len(keys_to_reset)} selected {key_type} keys."
|
||||
)
|
||||
|
||||
if not keys_to_reset:
|
||||
return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": "没有提供需要重置的密钥"}, status_code=400
|
||||
)
|
||||
|
||||
reset_count = 0
|
||||
errors = []
|
||||
@@ -338,53 +413,79 @@ async def reset_selected_key_fail_counts(
|
||||
if result:
|
||||
reset_count += 1
|
||||
else:
|
||||
logger.warning(f"Key not found during selective reset: {redact_key_for_logging(key)}")
|
||||
logger.warning(
|
||||
f"Key not found during selective reset: {redact_key_for_logging(key)}"
|
||||
)
|
||||
except Exception as key_error:
|
||||
logger.error(f"Error resetting key {redact_key_for_logging(key)}: {str(key_error)}")
|
||||
logger.error(
|
||||
f"Error resetting key {redact_key_for_logging(key)}: {str(key_error)}"
|
||||
)
|
||||
errors.append(f"Key {key}: {str(key_error)}")
|
||||
|
||||
if errors:
|
||||
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
||||
final_success = reset_count > 0
|
||||
status_code = 207 if final_success and errors else 500
|
||||
return JSONResponse({
|
||||
"success": final_success,
|
||||
"message": error_message,
|
||||
"reset_count": reset_count
|
||||
}, status_code=status_code)
|
||||
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
||||
final_success = reset_count > 0
|
||||
status_code = 207 if final_success and errors else 500
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": final_success,
|
||||
"message": error_message,
|
||||
"reset_count": reset_count,
|
||||
},
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
||||
"reset_count": reset_count
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
||||
"reset_count": reset_count,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
|
||||
logger.error(
|
||||
f"Failed to process reset selected key failure counts request: {str(e)}"
|
||||
)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": f"批量重置处理失败: {str(e)}"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reset-fail-count/{api_key}")
|
||||
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
async def reset_key_fail_count(
|
||||
api_key: str, key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""重置指定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
|
||||
logger.info(f"Resetting failure count for API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
logger.info(
|
||||
f"Resetting failure count for API key: {redact_key_for_logging(api_key)}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await key_manager.reset_key_failure_count(api_key)
|
||||
if result:
|
||||
return JSONResponse({"success": True, "message": "失败计数已重置"})
|
||||
return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": "未找到指定密钥"}, status_code=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset key failure count: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": f"重置失败: {str(e)}"}, status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/verify-key/{api_key}")
|
||||
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
|
||||
async def verify_key(
|
||||
api_key: str,
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""验证Gemini API密钥的有效性"""
|
||||
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
||||
logger.info("Verifying API key validity")
|
||||
|
||||
|
||||
try:
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
@@ -393,43 +494,47 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
parts=[{"text": "hi"}],
|
||||
)
|
||||
],
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10},
|
||||
)
|
||||
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
settings.TEST_MODEL, gemini_request, api_key
|
||||
)
|
||||
|
||||
|
||||
if response:
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return JSONResponse({"status": "valid"})
|
||||
except Exception as e:
|
||||
logger.error(f"Key verification failed: {str(e)}")
|
||||
|
||||
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count")
|
||||
|
||||
return JSONResponse({"status": "invalid", "error": str(e)})
|
||||
logger.warning(
|
||||
f"Verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
|
||||
)
|
||||
|
||||
return JSONResponse({"status": "invalid", "error": e.args[1]})
|
||||
|
||||
|
||||
@router.post("/verify-selected-keys")
|
||||
async def verify_selected_keys(
|
||||
request: VerifySelectedKeysRequest,
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""批量验证选定Gemini API密钥的有效性"""
|
||||
logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
|
||||
keys_to_verify = request.keys
|
||||
logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")
|
||||
logger.info(
|
||||
f"Received verification request for {len(keys_to_verify)} selected keys."
|
||||
)
|
||||
|
||||
if not keys_to_verify:
|
||||
return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": "没有提供需要验证的密钥"}, status_code=400
|
||||
)
|
||||
|
||||
successful_keys = []
|
||||
failed_keys = {}
|
||||
@@ -440,28 +545,36 @@ async def verify_selected_keys(
|
||||
try:
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
|
||||
generation_config={
|
||||
"temperature": 0.7,
|
||||
"topP": 1.0,
|
||||
"maxOutputTokens": 10,
|
||||
},
|
||||
)
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
settings.TEST_MODEL, gemini_request, api_key
|
||||
)
|
||||
successful_keys.append(api_key)
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return api_key, "valid", None
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.warning(f"Key verification failed for {redact_key_for_logging(api_key)}: {error_message}")
|
||||
error_message = e.args[1]
|
||||
logger.warning(
|
||||
f"Key verification failed for {redact_key_for_logging(api_key)}: {error_message}"
|
||||
)
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count")
|
||||
logger.warning(
|
||||
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
|
||||
)
|
||||
else:
|
||||
key_manager.key_failure_counts[api_key] = 1
|
||||
logger.warning(f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, initializing failure count to 1")
|
||||
failed_keys[api_key] = error_message
|
||||
key_manager.key_failure_counts[api_key] = 1
|
||||
logger.warning(
|
||||
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, initializing failure count to 1"
|
||||
)
|
||||
failed_keys[api_key] = {"error_message": e.args[1], "error_code": e.args[0]}
|
||||
return api_key, "invalid", error_message
|
||||
|
||||
tasks = [_verify_single_key(key) for key in keys_to_verify]
|
||||
@@ -469,34 +582,37 @@ async def verify_selected_keys(
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"An unexpected error occurred during bulk verification task: {result}")
|
||||
elif result:
|
||||
if not isinstance(result, Exception) and result:
|
||||
key, status, error = result
|
||||
elif isinstance(result, Exception):
|
||||
logger.error(f"Task execution error during bulk verification: {result}")
|
||||
logger.error(
|
||||
f"An unexpected error occurred during bulk verification task: {result}"
|
||||
)
|
||||
|
||||
valid_count = len(successful_keys)
|
||||
invalid_count = len(failed_keys)
|
||||
logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")
|
||||
logger.info(
|
||||
f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}"
|
||||
)
|
||||
|
||||
if failed_keys:
|
||||
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": failed_keys,
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": failed_keys,
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count,
|
||||
}
|
||||
)
|
||||
else:
|
||||
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": {},
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": 0
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": {},
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
@@ -8,19 +8,21 @@ from app.domain.openai_models import (
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
|
||||
from app.service.openai_compatiable.openai_compatiable_service import (
|
||||
OpenAICompatiableService,
|
||||
)
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
security_service = SecurityService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
@@ -38,7 +40,7 @@ async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager))
|
||||
|
||||
@router.get("/openai/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
@@ -47,6 +49,7 @@ async def list_models(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
return await openai_service.get_models(api_key)
|
||||
|
||||
@@ -55,7 +58,7 @@ async def list_models(
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
@@ -70,28 +73,56 @@ async def chat_completion(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
|
||||
|
||||
raw_response = None
|
||||
if is_image_chat:
|
||||
response = await openai_service.create_image_chat_completion(request, current_api_key)
|
||||
return response
|
||||
raw_response = await openai_service.create_image_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
else:
|
||||
response = await openai_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
raw_response = await openai_service.create_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
if request.stream:
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_response.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_response, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_response:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
else:
|
||||
return raw_response
|
||||
|
||||
|
||||
@router.post("/openai/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
request.model = settings.CREATE_IMAGE_MODEL
|
||||
return await openai_service.generate_images(request)
|
||||
|
||||
@@ -99,7 +130,7 @@ async def generate_image(
|
||||
@router.post("/openai/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
@@ -108,6 +139,7 @@ async def embedding(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
return await openai_service.create_embeddings(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
@@ -9,15 +9,15 @@ from app.domain.openai_models import (
|
||||
ImageGenerationRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.tts.tts_service import TTSService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.service.tts.tts_service import TTSService
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter()
|
||||
@@ -53,7 +53,7 @@ async def get_tts_service():
|
||||
@router.get("/v1/models")
|
||||
@router.get("/hf/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。"""
|
||||
@@ -61,6 +61,7 @@ async def list_models(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
return await model_service.get_gemini_openai_models(api_key)
|
||||
|
||||
@@ -70,7 +71,7 @@ async def list_models(
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
|
||||
@@ -85,6 +86,7 @@ async def chat_completion(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(request.model):
|
||||
@@ -92,28 +94,54 @@ async def chat_completion(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
|
||||
raw_response = None
|
||||
if is_image_chat:
|
||||
response = await chat_service.create_image_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
raw_response = await chat_service.create_image_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
else:
|
||||
response = await chat_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
raw_response = await chat_service.create_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_response.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_response, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_response:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
else:
|
||||
return raw_response
|
||||
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@router.post("/hf/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
):
|
||||
"""处理 OpenAI 图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
response = image_create_service.generate_images(request)
|
||||
return response
|
||||
|
||||
@@ -122,7 +150,7 @@ async def generate_image(
|
||||
@router.post("/hf/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""处理 OpenAI 文本嵌入请求。"""
|
||||
@@ -130,6 +158,7 @@ async def embedding(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
response = await embedding_service.create_embedding(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
@@ -162,7 +191,7 @@ async def get_keys_list(
|
||||
@router.post("/hf/v1/audio/speech")
|
||||
async def text_to_speech(
|
||||
request: TTSRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
):
|
||||
@@ -171,6 +200,7 @@ async def text_to_speech(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling TTS request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
audio_data = await tts_service.create_tts(request, api_key)
|
||||
return Response(content=audio_data, media_type="audio/wav")
|
||||
|
||||
@@ -24,10 +24,13 @@ from app.router import (
|
||||
)
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.stats.stats_service import StatsService
|
||||
from app.utils.static_version import get_static_url
|
||||
|
||||
logger = get_routes_logger()
|
||||
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
# 设置模板全局变量
|
||||
templates.env.globals["static_url"] = get_static_url
|
||||
|
||||
|
||||
def setup_routers(app: FastAPI) -> None:
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_vertex_express_logger
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_vertex_express_logger
|
||||
from app.service.chat.vertex_express_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter(prefix=f"/vertex-express/{API_VERSION}")
|
||||
@@ -37,8 +39,8 @@ async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
@@ -48,20 +50,30 @@ async def list_models(
|
||||
try:
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
raise HTTPException(
|
||||
status_code=503, detail="No valid API keys available to fetch models."
|
||||
)
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to fetch base models list."
|
||||
)
|
||||
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
model_mapping = {
|
||||
x.get("name", "").split("/", maxsplit=1)[-1]: x
|
||||
for x in models_json.get("models", [])
|
||||
}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
logger.warning(
|
||||
f"Base model '{base_name}' not found for derived model '{suffix}'."
|
||||
)
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
@@ -75,7 +87,7 @@ async def list_models(
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
@@ -87,7 +99,8 @@ async def list_models(
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
status_code=500,
|
||||
detail="Internal server error while fetching Gemini models list",
|
||||
) from e
|
||||
|
||||
|
||||
@@ -96,25 +109,30 @@ async def list_models(
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Content generation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini content generation request for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -124,24 +142,50 @@ async def generate_content(
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Streaming request initiation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini streaming content generation for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
raw_stream = chat_service.stream_generate_content(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_stream, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_stream:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
from app.config.config import settings
|
||||
@@ -6,9 +5,9 @@ from app.domain.gemini_models import GeminiContent, GeminiRequest
|
||||
from app.log.logger import Logger
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.error_log.error_log_service import delete_old_error_logs
|
||||
from app.service.files.files_service import get_files_service
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
||||
from app.service.files.files_service import get_files_service
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = Logger.setup_logger("scheduler")
|
||||
@@ -106,15 +105,16 @@ async def cleanup_expired_files():
|
||||
try:
|
||||
files_service = await get_files_service()
|
||||
deleted_count = await files_service.cleanup_expired_files()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Successfully cleaned up {deleted_count} expired files.")
|
||||
else:
|
||||
logger.info("No expired files to clean up.")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled file cleanup: {str(e)}", exc_info=True
|
||||
f"An error occurred during the scheduled file cleanup: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -122,44 +122,45 @@ def setup_scheduler():
|
||||
"""设置并启动 APScheduler"""
|
||||
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
||||
# 添加检查失败密钥的定时任务
|
||||
scheduler.add_job(
|
||||
check_failed_keys,
|
||||
"interval",
|
||||
hours=settings.CHECK_INTERVAL_HOURS,
|
||||
id="check_failed_keys_job",
|
||||
name="Check Failed API Keys",
|
||||
)
|
||||
logger.info(
|
||||
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
|
||||
)
|
||||
if settings.CHECK_INTERVAL_HOURS != 0:
|
||||
scheduler.add_job(
|
||||
check_failed_keys,
|
||||
"interval",
|
||||
hours=settings.CHECK_INTERVAL_HOURS,
|
||||
id="check_failed_keys_job",
|
||||
name="Check Failed API Keys",
|
||||
)
|
||||
logger.info(
|
||||
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
|
||||
)
|
||||
|
||||
# 新增:添加自动删除错误日志的定时任务,每天凌晨3点执行
|
||||
# 新增:添加自动删除错误日志的定时任务,每天凌晨0点执行
|
||||
scheduler.add_job(
|
||||
delete_old_error_logs,
|
||||
"cron",
|
||||
hour=3,
|
||||
hour=0,
|
||||
minute=0,
|
||||
id="delete_old_error_logs_job",
|
||||
name="Delete Old Error Logs",
|
||||
)
|
||||
logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")
|
||||
|
||||
# 新增:添加自动删除请求日志的定时任务,每天凌晨3点05分执行
|
||||
# 新增:添加自动删除请求日志的定时任务,每天凌晨0点执行
|
||||
scheduler.add_job(
|
||||
delete_old_request_logs_task,
|
||||
"cron",
|
||||
hour=3,
|
||||
minute=5,
|
||||
hour=0,
|
||||
minute=0,
|
||||
id="delete_old_request_logs_job",
|
||||
name="Delete Old Request Logs",
|
||||
)
|
||||
logger.info(
|
||||
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
|
||||
)
|
||||
|
||||
|
||||
# 新增:添加文件过期清理的定时任务,每小时执行一次
|
||||
if getattr(settings, 'FILES_CLEANUP_ENABLED', True):
|
||||
cleanup_interval = getattr(settings, 'FILES_CLEANUP_INTERVAL_HOURS', 1)
|
||||
if getattr(settings, "FILES_CLEANUP_ENABLED", True):
|
||||
cleanup_interval = getattr(settings, "FILES_CLEANUP_INTERVAL_HOURS", 1)
|
||||
scheduler.add_job(
|
||||
cleanup_expired_files,
|
||||
"interval",
|
||||
|
||||
@@ -365,13 +365,9 @@ class GeminiChatService:
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -379,7 +375,7 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
@@ -416,13 +412,9 @@ class GeminiChatService:
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Count tokens API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -430,7 +422,7 @@ class GeminiChatService:
|
||||
error_type="gemini-count-tokens",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
@@ -508,15 +500,11 @@ class GeminiChatService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
@@ -524,7 +512,9 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
@@ -537,11 +527,11 @@ class GeminiChatService:
|
||||
)
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
raise
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
@@ -285,7 +284,9 @@ class OpenAIChatService:
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
messages, instruction = self.message_converter.convert(request.messages)
|
||||
messages, instruction = self.message_converter.convert(
|
||||
request.messages, request.model
|
||||
)
|
||||
|
||||
payload = _build_payload(request, messages, instruction)
|
||||
|
||||
@@ -337,7 +338,8 @@ class OpenAIChatService:
|
||||
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"API call failed for model {model}: {error_log_msg}")
|
||||
|
||||
# 特别记录 max_tokens 相关的错误
|
||||
@@ -351,16 +353,13 @@ class OpenAIChatService:
|
||||
if "parts" in error_log_msg:
|
||||
logger.error("This is likely a response processing error")
|
||||
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
status_code = int(match.group(1)) if match else 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
@@ -538,27 +537,21 @@ class OpenAIChatService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
|
||||
)
|
||||
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
if isinstance(e, asyncio.TimeoutError):
|
||||
status_code = 408
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
@@ -575,7 +568,7 @@ class OpenAIChatService:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries, ceasing attempts for this request."
|
||||
)
|
||||
break
|
||||
raise
|
||||
else:
|
||||
logger.error(
|
||||
"KeyManager not available, cannot switch API key. Ceasing attempts for this request."
|
||||
@@ -586,6 +579,7 @@ class OpenAIChatService:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming model {model}."
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -598,13 +592,6 @@ class OpenAIChatService:
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
if not is_success:
|
||||
logger.error(
|
||||
f"Streaming failed permanently for model {model} after {retries} attempts."
|
||||
)
|
||||
yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self, request: ChatRequest, api_key: str
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
@@ -663,20 +650,23 @@ class OpenAIChatService:
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Stream image completion failed for model {model}: {e}"
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
request_msg=(
|
||||
{"image_data_truncated": image_data[:1000]}
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -714,19 +704,23 @@ class OpenAIChatService:
|
||||
return result
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Normal image completion failed for model {model}: {e}"
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
request_msg=(
|
||||
{"image_data_truncated": image_data[:1000]}
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
@@ -278,13 +277,9 @@ class GeminiChatService:
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -292,7 +287,7 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
@@ -356,15 +351,11 @@ class GeminiChatService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
@@ -372,7 +363,9 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
@@ -385,11 +378,11 @@ class GeminiChatService:
|
||||
)
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
raise
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
@@ -1,24 +1,31 @@
|
||||
# app/services/chat/api_client.py
|
||||
|
||||
from typing import Dict, Any, AsyncGenerator, Optional
|
||||
import httpx
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_api_client_logger
|
||||
from app.core.constants import DEFAULT_TIMEOUT
|
||||
from app.log.logger import get_api_client_logger
|
||||
|
||||
logger = get_api_client_logger()
|
||||
|
||||
|
||||
class ApiClient(ABC):
|
||||
"""API客户端基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
async def generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
async def stream_generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -50,7 +57,7 @@ class GeminiApiClient(ApiClient):
|
||||
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取可用的 Gemini 模型列表"""
|
||||
timeout = httpx.Timeout(timeout=5)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -73,11 +80,13 @@ class GeminiApiClient(ApiClient):
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求模型列表失败: {e}")
|
||||
return None
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
|
||||
async def generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -85,42 +94,33 @@ class GeminiApiClient(ApiClient):
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
|
||||
headers = self._prepare_headers()
|
||||
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(f"API call failed - Status: {response.status_code}, Content: {error_content}")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# 检查响应结构的基本信息
|
||||
if not response_data.get("candidates"):
|
||||
logger.warning("No candidates found in API response")
|
||||
|
||||
return response_data
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Request timeout: {e}")
|
||||
raise Exception(f"Request timeout: {e}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
raise Exception(f"Request error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
raise
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(
|
||||
f"API call failed - Status: {response.status_code}, Content: {error_content}"
|
||||
)
|
||||
raise Exception(response.status_code, error_content)
|
||||
response_data = response.json()
|
||||
|
||||
# 检查响应结构的基本信息
|
||||
if not response_data.get("candidates"):
|
||||
logger.warning("No candidates found in API response")
|
||||
|
||||
return response_data
|
||||
|
||||
async def stream_generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -132,15 +132,19 @@ class GeminiApiClient(ApiClient):
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
async with client.stream(
|
||||
method="POST", url=url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
raise Exception(response.status_code, error_msg)
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
async def count_tokens(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
@@ -158,14 +162,16 @@ class GeminiApiClient(ApiClient):
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def embed_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
async def embed_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""单一嵌入内容生成"""
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -177,32 +183,22 @@ class GeminiApiClient(ApiClient):
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:embedContent?key={api_key}"
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Embedding request timeout: {e}")
|
||||
raise Exception(f"Request timeout: {e}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Embedding request error: {e}")
|
||||
raise Exception(f"Request error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected embedding error: {e}")
|
||||
raise
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(
|
||||
f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}"
|
||||
)
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def batch_embed_contents(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
async def batch_embed_contents(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""批量嵌入内容生成"""
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -214,26 +210,14 @@ class GeminiApiClient(ApiClient):
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:batchEmbedContents?key={api_key}"
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Batch embedding request timeout: {e}")
|
||||
raise Exception(f"Request timeout: {e}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Batch embedding request error: {e}")
|
||||
raise Exception(f"Request error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected batch embedding error: {e}")
|
||||
raise
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(
|
||||
f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}"
|
||||
)
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
@@ -242,7 +226,7 @@ class OpenaiApiClient(ApiClient):
|
||||
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def _prepare_headers(self, api_key: str) -> Dict[str, str]:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
if settings.CUSTOM_HEADERS:
|
||||
@@ -267,12 +251,16 @@ class OpenaiApiClient(ApiClient):
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
async def generate_content(
|
||||
self, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
logger.info(f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}")
|
||||
logger.info(
|
||||
f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}"
|
||||
)
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -287,10 +275,12 @@ class OpenaiApiClient(ApiClient):
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
|
||||
async def stream_generate_content(
|
||||
self, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
@@ -303,17 +293,21 @@ class OpenaiApiClient(ApiClient):
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
async with client.stream(
|
||||
method="POST", url=url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
raise Exception(response.status_code, error_msg)
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
|
||||
|
||||
async def create_embeddings(
|
||||
self, input: str, model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -332,10 +326,12 @@ class OpenaiApiClient(ApiClient):
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
|
||||
async def generate_images(
|
||||
self, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
@@ -352,5 +348,5 @@ class OpenaiApiClient(ApiClient):
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
import re
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
@@ -56,13 +55,9 @@ class EmbeddingService:
|
||||
raise e
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
status_code = 500
|
||||
error_log_msg = f"Generic error: {e}"
|
||||
logger.error(f"Error creating embedding (Exception): {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", str(e))
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
@@ -74,7 +69,11 @@ class EmbeddingService:
|
||||
error_type="openai-embedding",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request_msg_log,
|
||||
request_msg=(
|
||||
request_msg_log
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
await add_request_log(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# app/service/embedding/gemini_embedding_service.py
|
||||
|
||||
import datetime
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
@@ -69,13 +68,9 @@ class GeminiEmbeddingService:
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Single embedding API call failed: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -83,7 +78,7 @@ class GeminiEmbeddingService:
|
||||
error_type="gemini-embed-single",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
@@ -119,13 +114,9 @@ class GeminiEmbeddingService:
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Batch embedding API call failed: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -133,7 +124,7 @@ class GeminiEmbeddingService:
|
||||
error_type="gemini-embed-batch",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
@@ -28,7 +28,7 @@ async def delete_old_error_logs():
|
||||
)
|
||||
return
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
|
||||
|
||||
logger.info(
|
||||
f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.config.config import settings
|
||||
from app.core.constants import VALID_IMAGE_RATIOS
|
||||
from app.domain.openai_models import ImageGenerationRequest
|
||||
from app.log.logger import get_image_create_logger
|
||||
from app.utils.helpers import is_image_upload_configured
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
logger = get_image_create_logger()
|
||||
@@ -97,12 +98,18 @@ class ImageCreateService:
|
||||
image_data = generated_image.image.image_bytes
|
||||
image_uploader = None
|
||||
|
||||
if request.response_format == "b64_json":
|
||||
# Return base64 if explicitly requested or if no uploader is configured
|
||||
if (
|
||||
request.response_format == "b64_json"
|
||||
or not is_image_upload_configured(settings)
|
||||
):
|
||||
base64_image = base64.b64encode(image_data).decode("utf-8")
|
||||
images_data.append(
|
||||
{"b64_json": base64_image, "revised_prompt": request.prompt}
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Upload to configured provider
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
|
||||
@@ -115,6 +122,7 @@ class ImageCreateService:
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.PICGO_API_KEY,
|
||||
api_url=settings.PICGO_API_URL,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
@@ -123,6 +131,16 @@ class ImageCreateService:
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "aliyun_oss":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
access_key=settings.OSS_ACCESS_KEY,
|
||||
access_key_secret=settings.OSS_ACCESS_KEY_SECRET,
|
||||
bucket_name=settings.OSS_BUCKET_NAME,
|
||||
endpoint=settings.OSS_ENDPOINT,
|
||||
region=settings.OSS_REGION,
|
||||
use_internal=False
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, Union
|
||||
|
||||
@@ -80,13 +78,9 @@ class OpenAICompatiableService:
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -94,7 +88,7 @@ class OpenAICompatiableService:
|
||||
error_type="openai-compatiable-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request,
|
||||
request_msg=request if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
@@ -138,15 +132,11 @@ class OpenAICompatiableService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
@@ -154,7 +144,9 @@ class OpenAICompatiableService:
|
||||
error_type="openai-compatiable-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
@@ -170,14 +162,14 @@ class OpenAICompatiableService:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries."
|
||||
)
|
||||
break
|
||||
raise
|
||||
else:
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -189,6 +181,3 @@ class OpenAICompatiableService:
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
if not is_success and retries >= max_retries:
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
Service for request log operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from app.database.connection import database
|
||||
from app.config.config import settings
|
||||
from app.database.connection import database
|
||||
from app.database.models import RequestLog
|
||||
from app.log.logger import get_request_log_logger
|
||||
|
||||
@@ -30,7 +30,7 @@ async def delete_old_request_logs_task():
|
||||
)
|
||||
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
|
||||
|
||||
query = delete(RequestLog).where(RequestLog.request_time < cutoff_date)
|
||||
|
||||
|
||||
@@ -3,14 +3,16 @@
|
||||
继承自原始聊天服务,添加原生Gemini TTS支持(单人和多人),保持向后兼容
|
||||
"""
|
||||
|
||||
import time
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_response_handler import TTSResponseHandler
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.database.services import add_request_log, add_error_log
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_response_handler import TTSResponseHandler
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
@@ -28,7 +30,9 @@ class TTSGeminiChatService(GeminiChatService):
|
||||
super().__init__(base_url, key_manager)
|
||||
# 使用TTS响应处理器替换原始处理器
|
||||
self.response_handler = TTSResponseHandler()
|
||||
logger.info("TTS Gemini Chat Service initialized with multi-speaker TTS support")
|
||||
logger.info(
|
||||
"TTS Gemini Chat Service initialized with multi-speaker TTS support"
|
||||
)
|
||||
|
||||
async def generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
@@ -55,7 +59,9 @@ class TTSGeminiChatService(GeminiChatService):
|
||||
logger.error(f"TTS API call failed with error: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_tts_request(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
|
||||
async def _handle_tts_request(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理TTS特定的请求,包含完整的日志记录功能
|
||||
"""
|
||||
@@ -89,14 +95,24 @@ class TTSGeminiChatService(GeminiChatService):
|
||||
if request.generationConfig:
|
||||
# 添加TTS特定字段
|
||||
if request.generationConfig.responseModalities:
|
||||
payload["generationConfig"]["responseModalities"] = request.generationConfig.responseModalities
|
||||
logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}")
|
||||
payload["generationConfig"][
|
||||
"responseModalities"
|
||||
] = request.generationConfig.responseModalities
|
||||
logger.info(
|
||||
f"Added responseModalities: {request.generationConfig.responseModalities}"
|
||||
)
|
||||
|
||||
if request.generationConfig.speechConfig:
|
||||
payload["generationConfig"]["speechConfig"] = request.generationConfig.speechConfig
|
||||
logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}")
|
||||
payload["generationConfig"][
|
||||
"speechConfig"
|
||||
] = request.generationConfig.speechConfig
|
||||
logger.info(
|
||||
f"Added speechConfig: {request.generationConfig.speechConfig}"
|
||||
)
|
||||
else:
|
||||
logger.warning("No generationConfig found in request, TTS fields may be missing")
|
||||
logger.warning(
|
||||
"No generationConfig found in request, TTS fields may be missing"
|
||||
)
|
||||
|
||||
logger.info(f"TTS payload before API call: {payload}")
|
||||
|
||||
@@ -117,6 +133,7 @@ class TTSGeminiChatService(GeminiChatService):
|
||||
|
||||
# 尝试从错误消息中提取状态码
|
||||
import re
|
||||
|
||||
match = re.search(r"status code (\d+)", error_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
@@ -130,7 +147,11 @@ class TTSGeminiChatService(GeminiChatService):
|
||||
error_type="tts-api-error",
|
||||
error_log=error_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request.model_dump(exclude_none=False)
|
||||
request_msg=(
|
||||
request.model_dump(exclude_none=False)
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
logger.error(f"TTS API call failed: {error_msg}")
|
||||
@@ -147,5 +168,5 @@ class TTSGeminiChatService(GeminiChatService):
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ class TTSService:
|
||||
error_log_msg = ""
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
response =await client.aio.models.generate_content(
|
||||
response = await client.aio.models.generate_content(
|
||||
model=settings.TTS_MODEL,
|
||||
contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
|
||||
config={
|
||||
@@ -48,7 +48,11 @@ class TTSService:
|
||||
"speech_config": {
|
||||
"voice_config": {
|
||||
"prebuilt_voice_config": {
|
||||
"voice_name": request.voice if request.voice in TTS_VOICE_NAMES else settings.TTS_VOICE_NAME
|
||||
"voice_name": (
|
||||
request.voice
|
||||
if request.voice in TTS_VOICE_NAMES
|
||||
else settings.TTS_VOICE_NAME
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -59,7 +63,9 @@ class TTSService:
|
||||
and response.candidates[0].content.parts
|
||||
and response.candidates[0].content.parts[0].inline_data
|
||||
):
|
||||
raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
|
||||
raw_audio_data = (
|
||||
response.candidates[0].content.parts[0].inline_data.data
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return _create_wav_file(raw_audio_data)
|
||||
@@ -83,13 +89,17 @@ class TTSService:
|
||||
error_type="google-tts",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request.input
|
||||
)
|
||||
request_msg=(
|
||||
request.input
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=settings.TTS_MODEL,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
@@ -104,6 +104,24 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
});
|
||||
}
|
||||
|
||||
// 检查间隔小时数输入控制
|
||||
const checkIntervalInput = document.getElementById("CHECK_INTERVAL_HOURS");
|
||||
if (checkIntervalInput) {
|
||||
checkIntervalInput.addEventListener("input", function () {
|
||||
let value = parseFloat(this.value);
|
||||
if (isNaN(value) || value < 0) {
|
||||
this.value = 0;
|
||||
}
|
||||
});
|
||||
|
||||
checkIntervalInput.addEventListener("change", function () {
|
||||
let value = parseFloat(this.value);
|
||||
if (isNaN(value) || value < 0) {
|
||||
this.value = 0;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Toggle switch events
|
||||
const toggleSwitches = document.querySelectorAll(".toggle-switch");
|
||||
toggleSwitches.forEach((toggleSwitch) => {
|
||||
@@ -771,6 +789,10 @@ async function initConfig() {
|
||||
if (typeof config.AUTO_DELETE_ERROR_LOGS_DAYS === "undefined") {
|
||||
config.AUTO_DELETE_ERROR_LOGS_DAYS = 7;
|
||||
}
|
||||
// 错误日志是否记录请求体(默认不记录)
|
||||
if (typeof config.ERROR_LOG_RECORD_REQUEST_BODY === "undefined") {
|
||||
config.ERROR_LOG_RECORD_REQUEST_BODY = false;
|
||||
}
|
||||
// --- 结束:处理自动删除错误日志配置的默认值 ---
|
||||
|
||||
// --- 新增:处理自动删除请求日志配置的默认值 ---
|
||||
|
||||
@@ -541,30 +541,13 @@ function showVerificationResultModal(data) {
|
||||
const errorGroups = {};
|
||||
Object.entries(failedKeys).forEach(([key, error]) => {
|
||||
// 提取错误码或使用完整错误信息作为分组键
|
||||
let errorCode = error;
|
||||
|
||||
// 尝试提取常见的错误码模式
|
||||
const errorCodePatterns = [
|
||||
/status code (\d+)/,
|
||||
];
|
||||
|
||||
for (const pattern of errorCodePatterns) {
|
||||
const match = error.match(pattern);
|
||||
if (match) {
|
||||
errorCode = match[1] || match[0];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有匹配到特定模式,使用500
|
||||
if (errorCode === error) {
|
||||
errorCode = 500;
|
||||
}
|
||||
let errorCode = error["error_code"];
|
||||
let errorMessage = error["error_message"];
|
||||
|
||||
if (!errorGroups[errorCode]) {
|
||||
errorGroups[errorCode] = [];
|
||||
}
|
||||
errorGroups[errorCode].push({ key, error });
|
||||
errorGroups[errorCode].push({ key, errorMessage });
|
||||
});
|
||||
|
||||
// 创建分组展示容器
|
||||
@@ -609,7 +592,7 @@ function showVerificationResultModal(data) {
|
||||
const keysList = document.createElement("div");
|
||||
keysList.className = "group-keys-list space-y-1";
|
||||
|
||||
keyErrorPairs.forEach(({ key, error }) => {
|
||||
keyErrorPairs.forEach(({ key, errorMessage }) => {
|
||||
const keyItem = document.createElement("div");
|
||||
keyItem.className = "flex flex-col items-start bg-gray-50 p-2 rounded border";
|
||||
|
||||
@@ -624,7 +607,7 @@ function showVerificationResultModal(data) {
|
||||
const detailsButton = document.createElement("button");
|
||||
detailsButton.className = "ml-2 px-2 py-0.5 bg-red-200 hover:bg-red-300 text-red-700 text-xs rounded transition-colors";
|
||||
detailsButton.innerHTML = '<i class="fas fa-info-circle mr-1"></i>详情';
|
||||
detailsButton.dataset.error = error;
|
||||
detailsButton.dataset.error = errorMessage;
|
||||
detailsButton.onclick = (e) => {
|
||||
e.stopPropagation();
|
||||
const button = e.currentTarget;
|
||||
@@ -984,7 +967,6 @@ function initializeGlobalBatchVerificationHandlers() {
|
||||
document.getElementById("verifyModal").classList.add("hidden");
|
||||
};
|
||||
|
||||
// executeVerifyAll 变为 initializeGlobalBatchVerificationHandlers 的局部函数
|
||||
async function executeVerifyAll(type) {
|
||||
closeVerifyModal();
|
||||
const keysToVerify = getSelectedKeys(type);
|
||||
@@ -1055,8 +1037,6 @@ function initializeGlobalBatchVerificationHandlers() {
|
||||
invalid_count: Object.keys(allFailedKeys).length
|
||||
});
|
||||
}
|
||||
// The confirmButton.onclick in showVerifyModal (defined earlier in initializeGlobalBatchVerificationHandlers)
|
||||
// will correctly reference this local executeVerifyAll due to closure.
|
||||
}
|
||||
|
||||
// --- 进度条模态框函数 ---
|
||||
@@ -2550,6 +2530,7 @@ function showVerifyModalForAllKeys(allKeys) {
|
||||
modalElement.classList.remove("hidden");
|
||||
}
|
||||
|
||||
|
||||
// 执行验证所有密钥
|
||||
async function executeVerifyAllKeys(allKeys) {
|
||||
closeVerifyModal();
|
||||
|
||||
@@ -51,7 +51,7 @@
|
||||
</div>
|
||||
|
||||
<h2 class="text-3xl font-extrabold text-center text-gray-800 mb-8 animate-slide-down">
|
||||
<img src="/static/icons/logo.png" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
|
||||
<img src="{{ static_url('icons/logo.png') }}" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
|
||||
Gemini Balance
|
||||
</h2>
|
||||
|
||||
|
||||
@@ -4,21 +4,21 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>{% block title %}Gemini Balance{% endblock %}</title>
|
||||
<link rel="manifest" href="/static/manifest.json" />
|
||||
<link rel="manifest" href="{{ static_url('manifest.json') }}" />
|
||||
<meta name="theme-color" content="#4F46E5" />
|
||||
<meta name="apple-mobile-web-app-capable" content="yes" />
|
||||
<meta name="apple-mobile-web-app-status-bar-style" content="black" />
|
||||
<meta name="apple-mobile-web-app-title" content="GBalance" />
|
||||
<link rel="icon" href="/static/icons/icon-192x192.png" />
|
||||
<link rel="icon" href="{{ static_url('icons/icon-192x192.png') }}" />
|
||||
<link
|
||||
href="/static/css/fonts.css"
|
||||
href="{{ static_url('css/fonts.css') }}"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css"
|
||||
/>
|
||||
<script src="/static/js/tailwindcss.js"></script>
|
||||
<script src="{{ static_url('js/tailwindcss.js') }}"></script>
|
||||
<script>
|
||||
tailwind.config = {
|
||||
theme: {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -45,6 +45,179 @@ endblock %} {% block head_extra_styles %}
|
||||
.search-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
/* 移动端主容器布局 */
|
||||
.mobile-buttons-container {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
gap: 1rem !important;
|
||||
align-items: stretch !important;
|
||||
width: 100% !important;
|
||||
padding: 0 !important;
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
/* 移动端搜索控件布局优化 */
|
||||
.mobile-search-controls {
|
||||
grid-template-columns: 1fr !important;
|
||||
gap: 0.75rem !important;
|
||||
width: 100% !important;
|
||||
margin-bottom: 0.5rem !important;
|
||||
}
|
||||
|
||||
/* 按钮容器在移动端的布局 */
|
||||
.buttons-container-responsive {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
gap: 0.5rem !important;
|
||||
width: 100% !important;
|
||||
align-items: stretch !important;
|
||||
justify-content: stretch !important;
|
||||
}
|
||||
|
||||
/* 移动端所有按钮样式 */
|
||||
.buttons-container-responsive button {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
justify-content: center !important;
|
||||
text-align: center !important;
|
||||
min-width: 0 !important;
|
||||
flex-shrink: 0 !important;
|
||||
box-sizing: border-box !important;
|
||||
padding: 0.5rem 1rem !important;
|
||||
font-size: 0.875rem !important;
|
||||
white-space: nowrap !important;
|
||||
overflow: hidden !important;
|
||||
text-overflow: ellipsis !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* 中等屏幕优化 */
|
||||
@media (max-width: 1024px) and (min-width: 769px) {
|
||||
.buttons-container-responsive {
|
||||
flex-wrap: wrap !important;
|
||||
justify-content: center !important;
|
||||
}
|
||||
|
||||
.buttons-container-responsive button {
|
||||
flex-shrink: 1 !important;
|
||||
min-width: 0 !important;
|
||||
padding-left: 0.75rem !important;
|
||||
padding-right: 0.75rem !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* 小屏幕(手机)特殊优化 - 确保按钮在边框内 */
|
||||
@media (max-width: 640px) {
|
||||
/* 强制重写主容器布局 */
|
||||
.mobile-buttons-container {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
width: 100% !important;
|
||||
padding: 0 !important;
|
||||
margin: 0 !important;
|
||||
gap: 1rem !important;
|
||||
overflow: visible !important;
|
||||
}
|
||||
|
||||
/* 搜索区域在移动端占满宽度 */
|
||||
.mobile-search-controls {
|
||||
width: 100% !important;
|
||||
box-sizing: border-box !important;
|
||||
}
|
||||
|
||||
/* 按钮区域完全重新布局 */
|
||||
.buttons-container-responsive {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
gap: 0.5rem !important;
|
||||
padding: 0 !important;
|
||||
margin: 0 !important;
|
||||
box-sizing: border-box !important;
|
||||
overflow: visible !important;
|
||||
}
|
||||
|
||||
/* 所有按钮统一样式 */
|
||||
.buttons-container-responsive button {
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
box-sizing: border-box !important;
|
||||
padding: 0.5rem 1rem !important;
|
||||
margin: 0 !important;
|
||||
font-size: 0.875rem !important;
|
||||
line-height: 1.25rem !important;
|
||||
border-radius: 0.5rem !important;
|
||||
white-space: nowrap !important;
|
||||
overflow: hidden !important;
|
||||
text-overflow: ellipsis !important;
|
||||
flex-shrink: 0 !important;
|
||||
}
|
||||
|
||||
/* 特别针对清空全部按钮 */
|
||||
#deleteAllLogsBtn {
|
||||
background-color: #f87171 !important;
|
||||
border: 1px solid #f87171 !important;
|
||||
}
|
||||
|
||||
#deleteAllLogsBtn:hover {
|
||||
background-color: #ef4444 !important;
|
||||
border: 1px solid #ef4444 !important;
|
||||
}
|
||||
|
||||
/* 确保容器不会溢出父级 */
|
||||
.mobile-buttons-container,
|
||||
.mobile-buttons-container > *,
|
||||
.buttons-container-responsive,
|
||||
.buttons-container-responsive > * {
|
||||
max-width: 100% !important;
|
||||
box-sizing: border-box !important;
|
||||
}
|
||||
|
||||
/* 额外的安全边距控制 */
|
||||
.mobile-buttons-container .grid {
|
||||
padding-left: 0 !important;
|
||||
padding-right: 0 !important;
|
||||
margin-left: 0 !important;
|
||||
margin-right: 0 !important;
|
||||
}
|
||||
|
||||
/* 确保主内容区域有适当的内边距 */
|
||||
.rounded-xl.p-6 {
|
||||
padding-left: 1rem !important;
|
||||
padding-right: 1rem !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* 超小屏幕额外优化 */
|
||||
@media (max-width: 480px) {
|
||||
.mobile-buttons-container {
|
||||
gap: 0.75rem !important;
|
||||
}
|
||||
|
||||
.buttons-container-responsive {
|
||||
gap: 0.4rem !important;
|
||||
}
|
||||
|
||||
.buttons-container-responsive button {
|
||||
padding: 0.4rem 0.8rem !important;
|
||||
font-size: 0.8rem !important;
|
||||
}
|
||||
|
||||
/* 主容器内边距进一步缩小 */
|
||||
.rounded-xl.p-6 {
|
||||
padding-left: 0.75rem !important;
|
||||
padding-right: 0.75rem !important;
|
||||
}
|
||||
|
||||
/* 确保清空全部按钮文字不会太挤 */
|
||||
#deleteAllLogsBtn i {
|
||||
margin-right: 0.25rem !important;
|
||||
}
|
||||
}
|
||||
|
||||
input[type="text"],
|
||||
@@ -586,7 +759,7 @@ endblock %} {% block head_extra_styles %}
|
||||
class="text-3xl font-extrabold text-center text-gray-800 mb-4"
|
||||
>
|
||||
<img
|
||||
src="/static/icons/logo.png"
|
||||
src="{{ static_url('icons/logo.png') }}"
|
||||
alt="Gemini Balance Logo"
|
||||
class="h-9 inline-block align-middle mr-2"
|
||||
/>
|
||||
@@ -636,10 +809,10 @@ endblock %} {% block head_extra_styles %}
|
||||
|
||||
<!-- 搜索与操作控件 -->
|
||||
<div
|
||||
class="grid grid-cols-1 lg:grid-cols-[1fr_auto] items-center gap-4 mb-6"
|
||||
class="grid grid-cols-1 lg:grid-cols-[1fr_auto] items-center gap-4 mb-6 mobile-buttons-container"
|
||||
>
|
||||
<div
|
||||
class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 w-full"
|
||||
class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 w-full mobile-search-controls"
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
@@ -684,7 +857,7 @@ endblock %} {% block head_extra_styles %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-3 flex-shrink-0">
|
||||
<div class="flex items-center gap-3 flex-shrink-0 buttons-container-responsive">
|
||||
<button
|
||||
id="searchBtn"
|
||||
class="flex items-center justify-center px-4 py-1.5 bg-blue-600 hover:bg-blue-700 text-white rounded-lg font-medium transition-all duration-200 shadow-sm hover:shadow-md whitespace-nowrap"
|
||||
@@ -1041,7 +1214,7 @@ endblock %} {% block head_extra_styles %}
|
||||
</div>
|
||||
</div>
|
||||
{% endblock %} {% block body_scripts %}
|
||||
<script src="/static/js/error_logs.js"></script>
|
||||
<script src="{{ static_url('js/error_logs.js') }}"></script>
|
||||
<script>
|
||||
// error_logs.html specific JS initialization (if any)
|
||||
// e.g., initialize date pickers or other elements if needed
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""
|
||||
通用工具函数模块
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
import requests
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from app.config.config import Settings
|
||||
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
|
||||
|
||||
helper_logger = logging.getLogger("app.utils")
|
||||
@@ -20,23 +23,25 @@ VERSION_FILE_PATH = PROJECT_ROOT / "VERSION"
|
||||
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
从 base64 字符串中提取 MIME 类型和数据
|
||||
|
||||
|
||||
Args:
|
||||
base64_string: 可能包含 MIME 类型信息的 base64 字符串
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (mime_type, encoded_data)
|
||||
"""
|
||||
# 检查字符串是否以 "data:" 格式开始
|
||||
if base64_string.startswith('data:'):
|
||||
if base64_string.startswith("data:"):
|
||||
# 提取 MIME 类型和数据
|
||||
pattern = DATA_URL_PATTERN
|
||||
match = re.match(pattern, base64_string)
|
||||
if match:
|
||||
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
mime_type = (
|
||||
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
)
|
||||
encoded_data = match.group(2)
|
||||
return mime_type, encoded_data
|
||||
|
||||
|
||||
# 如果不是预期格式,假定它只是数据部分
|
||||
return None, base64_string
|
||||
|
||||
@@ -44,20 +49,20 @@ def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||
def convert_image_to_base64(url: str) -> str:
|
||||
"""
|
||||
将图片URL转换为base64编码
|
||||
|
||||
|
||||
Args:
|
||||
url: 图片URL
|
||||
|
||||
|
||||
Returns:
|
||||
str: base64编码的图片数据
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: 如果获取图片失败
|
||||
"""
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# 将图片内容转换为base64
|
||||
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||
img_data = base64.b64encode(response.content).decode("utf-8")
|
||||
return img_data
|
||||
else:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
@@ -66,64 +71,66 @@ def convert_image_to_base64(url: str) -> str:
|
||||
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
|
||||
"""
|
||||
格式化JSON响应
|
||||
|
||||
|
||||
Args:
|
||||
data: 要格式化的数据
|
||||
indent: 缩进空格数
|
||||
|
||||
|
||||
Returns:
|
||||
str: 格式化后的JSON字符串
|
||||
"""
|
||||
return json.dumps(data, indent=indent, ensure_ascii=False)
|
||||
|
||||
|
||||
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
|
||||
def parse_prompt_parameters(
|
||||
prompt: str, default_ratio: str = "1:1"
|
||||
) -> Tuple[str, int, str]:
|
||||
"""
|
||||
从prompt中解析参数
|
||||
|
||||
|
||||
支持的格式:
|
||||
- {n:数量} 例如: {n:2} 生成2张图片
|
||||
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||
|
||||
|
||||
Args:
|
||||
prompt: 提示文本
|
||||
default_ratio: 默认比例
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (清理后的提示文本, 图片数量, 比例)
|
||||
"""
|
||||
# 默认值
|
||||
n = 1
|
||||
aspect_ratio = default_ratio
|
||||
|
||||
|
||||
# 解析n参数
|
||||
n_match = re.search(r'{n:(\d+)}', prompt)
|
||||
n_match = re.search(r"{n:(\d+)}", prompt)
|
||||
if n_match:
|
||||
n = int(n_match.group(1))
|
||||
if n < 1 or n > 4:
|
||||
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||
prompt = prompt.replace(n_match.group(0), '').strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||
prompt = prompt.replace(n_match.group(0), "").strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
|
||||
if ratio_match:
|
||||
aspect_ratio = ratio_match.group(1)
|
||||
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||
raise ValueError(
|
||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||
)
|
||||
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||
|
||||
prompt = prompt.replace(ratio_match.group(0), "").strip()
|
||||
|
||||
return prompt, n, aspect_ratio
|
||||
|
||||
|
||||
def extract_image_urls_from_markdown(text: str) -> List[str]:
|
||||
"""
|
||||
从Markdown文本中提取图片URL
|
||||
|
||||
|
||||
Args:
|
||||
text: Markdown文本
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 图片URL列表
|
||||
"""
|
||||
@@ -135,23 +142,22 @@ def extract_image_urls_from_markdown(text: str) -> List[str]:
|
||||
def is_valid_api_key(key: str) -> bool:
|
||||
"""
|
||||
检查API密钥格式是否有效
|
||||
|
||||
|
||||
Args:
|
||||
key: API密钥
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果密钥格式有效则返回True
|
||||
"""
|
||||
# 检查Gemini API密钥格式
|
||||
if key.startswith('AIza'):
|
||||
if key.startswith("AIza"):
|
||||
return len(key) >= 30
|
||||
|
||||
# 检查OpenAI API密钥格式
|
||||
if key.startswith('sk-'):
|
||||
return len(key) >= 30
|
||||
|
||||
return False
|
||||
|
||||
# 检查OpenAI API密钥格式
|
||||
if key.startswith("sk-"):
|
||||
return len(key) >= 30
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def redact_key_for_logging(key: str) -> str:
|
||||
@@ -177,15 +183,49 @@ def get_current_version(default_version: str = "0.0.0") -> str:
|
||||
"""Reads the current version from the VERSION file."""
|
||||
version_file = VERSION_FILE_PATH
|
||||
try:
|
||||
with version_file.open('r', encoding='utf-8') as f:
|
||||
with version_file.open("r", encoding="utf-8") as f:
|
||||
version = f.read().strip()
|
||||
if not version:
|
||||
helper_logger.warning(f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'.")
|
||||
helper_logger.warning(
|
||||
f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'."
|
||||
)
|
||||
return default_version
|
||||
return version
|
||||
except FileNotFoundError:
|
||||
helper_logger.warning(f"VERSION file not found at '{version_file}'. Using default version '{default_version}'.")
|
||||
helper_logger.warning(
|
||||
f"VERSION file not found at '{version_file}'. Using default version '{default_version}'."
|
||||
)
|
||||
return default_version
|
||||
except IOError as e:
|
||||
helper_logger.error(f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'.")
|
||||
helper_logger.error(
|
||||
f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'."
|
||||
)
|
||||
return default_version
|
||||
|
||||
|
||||
def is_image_upload_configured(settings: Settings) -> bool:
|
||||
"""Return True only if a valid upload provider is selected and all required settings for that provider are present."""
|
||||
|
||||
provider = (getattr(settings, "UPLOAD_PROVIDER", "") or "").strip().lower()
|
||||
if provider == "smms":
|
||||
return bool(getattr(settings, "SMMS_SECRET_TOKEN", ""))
|
||||
if provider == "picgo":
|
||||
return bool(getattr(settings, "PICGO_API_KEY", ""))
|
||||
if provider == "aliyun_oss":
|
||||
return all(
|
||||
[
|
||||
getattr(settings, "OSS_ACCESS_KEY", ""),
|
||||
getattr(settings, "OSS_ACCESS_KEY_SECRET", ""),
|
||||
getattr(settings, "OSS_BUCKET_NAME", ""),
|
||||
getattr(settings, "OSS_ENDPOINT", ""),
|
||||
getattr(settings, "OSS_REGION", "")
|
||||
]
|
||||
)
|
||||
if provider == "cloudflare_imgbed":
|
||||
return all(
|
||||
[
|
||||
getattr(settings, "CLOUDFLARE_IMGBED_URL", ""),
|
||||
getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""),
|
||||
]
|
||||
)
|
||||
return False
|
||||
|
||||
127
app/utils/static_version.py
Normal file
127
app/utils/static_version.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
静态资源版本控制工具
|
||||
用于给CSS和JS文件添加版本参数,避免浏览器缓存问题
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.utils.helpers import get_current_version
|
||||
|
||||
|
||||
class StaticVersionManager:
|
||||
"""静态资源版本管理器"""
|
||||
|
||||
def __init__(self, static_dir: str = "app/static"):
|
||||
self.static_dir = Path(static_dir)
|
||||
self._version_cache: Dict[str, str] = {}
|
||||
self._use_file_hash = True # 是否使用文件哈希作为版本号
|
||||
|
||||
def get_version_for_file(self, file_path: str) -> str:
|
||||
"""
|
||||
获取文件的版本号
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径,如 'css/fonts.css'
|
||||
|
||||
Returns:
|
||||
版本号字符串
|
||||
"""
|
||||
if self._use_file_hash:
|
||||
return self._get_file_hash_version(file_path)
|
||||
else:
|
||||
return self._get_app_version()
|
||||
|
||||
def _get_file_hash_version(self, file_path: str) -> str:
|
||||
"""基于文件内容生成哈希版本号"""
|
||||
# 如果已经缓存过,直接返回
|
||||
if file_path in self._version_cache:
|
||||
return self._version_cache[file_path]
|
||||
|
||||
full_path = self.static_dir / file_path
|
||||
|
||||
if not full_path.exists():
|
||||
# 文件不存在,使用应用版本号作为fallback
|
||||
version = self._get_app_version()
|
||||
else:
|
||||
try:
|
||||
# 读取文件内容并计算MD5哈希
|
||||
with open(full_path, "rb") as f:
|
||||
content = f.read()
|
||||
hash_object = hashlib.md5(content)
|
||||
version = hash_object.hexdigest()[:8] # 取前8位
|
||||
except Exception:
|
||||
# 读取失败,使用应用版本号作为fallback
|
||||
version = self._get_app_version()
|
||||
|
||||
# 缓存结果
|
||||
self._version_cache[file_path] = version
|
||||
return version
|
||||
|
||||
def _get_app_version(self) -> str:
|
||||
"""获取应用程序版本号"""
|
||||
try:
|
||||
return get_current_version().replace(".", "")
|
||||
except Exception:
|
||||
# 如果获取版本失败,使用时间戳
|
||||
return str(int(time.time()))
|
||||
|
||||
def get_versioned_url(self, file_path: str) -> str:
|
||||
"""
|
||||
获取带版本参数的URL
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径
|
||||
|
||||
Returns:
|
||||
带版本参数的URL
|
||||
"""
|
||||
version = self.get_version_for_file(file_path)
|
||||
return f"/static/{file_path}?v={version}"
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空版本缓存"""
|
||||
self._version_cache.clear()
|
||||
|
||||
|
||||
# 全局实例
|
||||
_static_version_manager = StaticVersionManager()
|
||||
|
||||
|
||||
def get_static_url(file_path: str) -> str:
|
||||
"""
|
||||
获取静态资源的版本化URL
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径
|
||||
|
||||
Returns:
|
||||
带版本参数的完整URL
|
||||
|
||||
Example:
|
||||
get_static_url('css/fonts.css') -> '/static/css/fonts.css?v=a1b2c3d4'
|
||||
get_static_url('js/config_editor.js') -> '/static/js/config_editor.js?v=e5f6g7h8'
|
||||
"""
|
||||
return _static_version_manager.get_versioned_url(file_path)
|
||||
|
||||
|
||||
def clear_static_cache():
|
||||
"""清空静态资源版本缓存"""
|
||||
_static_version_manager.clear_cache()
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_cached_static_url(file_path: str) -> str:
|
||||
"""
|
||||
获取缓存的静态资源URL(用于开发环境)
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径
|
||||
|
||||
Returns:
|
||||
带版本参数的完整URL
|
||||
"""
|
||||
return get_static_url(file_path)
|
||||
@@ -2,6 +2,12 @@ import requests
|
||||
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
from app.log.logger import get_image_create_logger
|
||||
|
||||
class UploadErrorType(Enum):
|
||||
"""上传错误类型枚举"""
|
||||
@@ -179,9 +185,22 @@ class PicGoUploader(ImageUploader):
|
||||
"""
|
||||
try:
|
||||
# 准备请求头
|
||||
headers = {
|
||||
"X-API-Key": self.api_key
|
||||
}
|
||||
headers = {}
|
||||
|
||||
# 构建请求URL
|
||||
request_url = self.api_url
|
||||
|
||||
# 判断是否为默认PicGo URL,如果是则使用header认证,否则使用URL参数认证
|
||||
if self.api_url == "https://www.picgo.net/api/1/upload":
|
||||
headers["X-API-Key"] = self.api_key
|
||||
else:
|
||||
# 对于自定义URL,将API key作为查询参数添加到URL中
|
||||
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
|
||||
parsed_url = urlparse(request_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params["key"] = self.api_key
|
||||
new_query = urlencode(query_params, doseq=True)
|
||||
request_url = urlunparse(parsed_url._replace(query=new_query))
|
||||
|
||||
# 准备文件数据
|
||||
files = {
|
||||
@@ -190,7 +209,7 @@ class PicGoUploader(ImageUploader):
|
||||
|
||||
# 发送请求
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
request_url,
|
||||
headers=headers,
|
||||
files=files
|
||||
)
|
||||
@@ -201,6 +220,34 @@ class PicGoUploader(ImageUploader):
|
||||
# 解析响应
|
||||
result = response.json()
|
||||
|
||||
# 处理自定义PicGo服务器的响应格式
|
||||
if "success" in result and "result" in result:
|
||||
# 自定义PicGo服务器格式: {"success": true, "result": ["url"]}
|
||||
if result["success"]:
|
||||
image_url = result["result"][0] if result["result"] and len(result["result"]) > 0 else ""
|
||||
image_metadata = ImageMetadata(
|
||||
width=0,
|
||||
height=0,
|
||||
filename=filename,
|
||||
size=0,
|
||||
url=image_url,
|
||||
delete_url=None
|
||||
)
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
code="success",
|
||||
message="Upload success",
|
||||
data=image_metadata
|
||||
)
|
||||
else:
|
||||
raise UploadError(
|
||||
message="Upload failed",
|
||||
error_type=UploadErrorType.SERVER_ERROR,
|
||||
status_code=400,
|
||||
details=result
|
||||
)
|
||||
|
||||
# 处理官方PicGo服务器的响应格式
|
||||
# 验证上传是否成功
|
||||
if result.get("status_code") != 200:
|
||||
error_message = "Upload failed"
|
||||
@@ -259,6 +306,191 @@ class PicGoUploader(ImageUploader):
|
||||
)
|
||||
|
||||
|
||||
class AliyunOSSUploader(ImageUploader):
|
||||
"""阿里云OSS图片上传器"""
|
||||
|
||||
def __init__(self, access_key: str, access_key_secret: str, bucket_name: str,
|
||||
endpoint: str, region: str, use_internal: bool = False):
|
||||
"""
|
||||
初始化阿里云OSS上传器
|
||||
|
||||
Args:
|
||||
access_key: OSS访问密钥ID
|
||||
access_key_secret: OSS访问密钥
|
||||
bucket_name: OSS存储桶名称
|
||||
endpoint: OSS端点地址
|
||||
region: OSS区域
|
||||
use_internal: 是否使用内网端点
|
||||
"""
|
||||
self.access_key = access_key
|
||||
self.access_key_secret = access_key_secret
|
||||
self.bucket_name = bucket_name
|
||||
self.endpoint = endpoint
|
||||
self.region = region
|
||||
self.use_internal = use_internal
|
||||
self.logger = get_image_create_logger()
|
||||
|
||||
# 构建请求URL
|
||||
if not endpoint.startswith(('http://', 'https://')):
|
||||
self.base_url = f"https://{bucket_name}.{endpoint}"
|
||||
else:
|
||||
self.base_url = f"{endpoint}/{bucket_name}"
|
||||
|
||||
self.logger.info(f"Initialized AliyunOSSUploader for bucket: {bucket_name}, region: {region}")
|
||||
|
||||
def _sign_request(self, method: str, path: str, headers: dict, content: bytes = b'') -> dict:
|
||||
"""
|
||||
为OSS请求生成签名
|
||||
|
||||
Args:
|
||||
method: HTTP方法
|
||||
path: 请求路径
|
||||
headers: 请求头
|
||||
content: 请求内容
|
||||
|
||||
Returns:
|
||||
包含签名的请求头
|
||||
"""
|
||||
# 计算Content-MD5
|
||||
content_md5 = base64.b64encode(hashlib.md5(content).digest()).decode('utf-8') if content else ''
|
||||
|
||||
# 设置日期
|
||||
date = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
|
||||
|
||||
# 更新headers
|
||||
headers['Date'] = date
|
||||
if content_md5:
|
||||
headers['Content-MD5'] = content_md5
|
||||
headers['Content-Type'] = headers.get('Content-Type', 'image/png')
|
||||
|
||||
# 构建CanonicalizedOSSHeaders
|
||||
oss_headers = []
|
||||
for key, value in sorted(headers.items()):
|
||||
if key.lower().startswith('x-oss-'):
|
||||
oss_headers.append(f"{key.lower()}:{value}")
|
||||
canonicalized_oss_headers = '\n'.join(oss_headers)
|
||||
if canonicalized_oss_headers:
|
||||
canonicalized_oss_headers += '\n'
|
||||
|
||||
# 构建CanonicalizedResource
|
||||
canonicalized_resource = f"/{self.bucket_name}{path}"
|
||||
|
||||
# 构建StringToSign
|
||||
string_to_sign = f"{method}\n{content_md5}\n{headers.get('Content-Type', '')}\n{date}\n{canonicalized_oss_headers}{canonicalized_resource}"
|
||||
|
||||
# 计算签名
|
||||
signature = base64.b64encode(
|
||||
hmac.new(
|
||||
self.access_key_secret.encode('utf-8'),
|
||||
string_to_sign.encode('utf-8'),
|
||||
hashlib.sha1
|
||||
).digest()
|
||||
).decode('utf-8')
|
||||
|
||||
# 添加Authorization头
|
||||
headers['Authorization'] = f"OSS {self.access_key}:{signature}"
|
||||
|
||||
return headers
|
||||
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
"""
|
||||
上传图片到阿里云OSS
|
||||
|
||||
Args:
|
||||
file: 图片文件二进制数据
|
||||
filename: 文件名(将作为OSS对象的key)
|
||||
|
||||
Returns:
|
||||
UploadResponse: 上传响应对象
|
||||
|
||||
Raises:
|
||||
UploadError: 上传失败时抛出异常
|
||||
"""
|
||||
# 记录开始上传的日志
|
||||
self.logger.info(f"Starting OSS upload for file: {filename}, size: {len(file)} bytes")
|
||||
|
||||
try:
|
||||
# 构建对象路径
|
||||
object_key = f"/{filename}"
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
'Content-Type': 'image/png',
|
||||
'x-oss-object-acl': 'public-read' # 设置为公共读
|
||||
}
|
||||
|
||||
# 签名请求
|
||||
signed_headers = self._sign_request('PUT', object_key, headers, file)
|
||||
|
||||
# 构建完整URL
|
||||
upload_url = f"{self.base_url}{object_key}"
|
||||
self.logger.debug(f"OSS upload URL: {upload_url}")
|
||||
|
||||
# 发送请求
|
||||
response = requests.put(
|
||||
upload_url,
|
||||
data=file,
|
||||
headers=signed_headers
|
||||
)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
error_msg = f"OSS upload failed with status {response.status_code}, response: {response.text}"
|
||||
self.logger.error(f"OSS upload failed for {filename}: {error_msg}")
|
||||
raise UploadError(
|
||||
message=f"OSS upload failed with status {response.status_code}",
|
||||
error_type=UploadErrorType.SERVER_ERROR,
|
||||
status_code=response.status_code,
|
||||
details={'response': response.text}
|
||||
)
|
||||
|
||||
# 构建访问URL
|
||||
if self.endpoint.startswith(('http://', 'https://')):
|
||||
access_url = f"{self.endpoint}/{self.bucket_name}{object_key}"
|
||||
else:
|
||||
access_url = f"https://{self.bucket_name}.{self.endpoint}{object_key}"
|
||||
|
||||
# 构建图片元数据
|
||||
image_metadata = ImageMetadata(
|
||||
width=0, # OSS PUT不返回图片尺寸
|
||||
height=0,
|
||||
filename=filename,
|
||||
size=len(file),
|
||||
url=access_url,
|
||||
delete_url=None # OSS需要单独的删除操作
|
||||
)
|
||||
|
||||
# 记录上传成功的日志
|
||||
self.logger.info(f"OSS upload successful for {filename}, URL: {access_url}")
|
||||
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
code="success",
|
||||
message="Upload to Aliyun OSS success",
|
||||
data=image_metadata
|
||||
)
|
||||
|
||||
except requests.RequestException as e:
|
||||
error_msg = f"OSS upload request failed: {str(e)}"
|
||||
self.logger.error(f"OSS upload request failed for {filename}: {error_msg}")
|
||||
raise UploadError(
|
||||
message=error_msg,
|
||||
error_type=UploadErrorType.NETWORK_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
except UploadError:
|
||||
# UploadError 已经被记录了,直接重新抛出
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"OSS upload failed: {str(e)}"
|
||||
self.logger.error(f"OSS upload unexpected error for {filename}: {error_msg}")
|
||||
raise UploadError(
|
||||
message=error_msg,
|
||||
error_type=UploadErrorType.UNKNOWN,
|
||||
original_error=e
|
||||
)
|
||||
|
||||
|
||||
class CloudFlareImgBedUploader(ImageUploader):
|
||||
"""CloudFlare图床上传器"""
|
||||
|
||||
@@ -389,7 +621,7 @@ class ImageUploaderFactory:
|
||||
credentials["secret_key"]
|
||||
)
|
||||
elif provider == "picgo":
|
||||
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
|
||||
api_url = credentials.get("api_url") or "https://www.picgo.net/api/1/upload"
|
||||
return PicGoUploader(credentials["api_key"], api_url)
|
||||
elif provider == "cloudflare_imgbed":
|
||||
return CloudFlareImgBedUploader(
|
||||
@@ -397,4 +629,13 @@ class ImageUploaderFactory:
|
||||
credentials["base_url"],
|
||||
credentials.get("upload_folder", ""),
|
||||
)
|
||||
elif provider == "aliyun_oss":
|
||||
return AliyunOSSUploader(
|
||||
credentials["access_key"],
|
||||
credentials["access_key_secret"],
|
||||
credentials["bucket_name"],
|
||||
credentials["endpoint"],
|
||||
credentials["region"],
|
||||
credentials.get("use_internal", False)
|
||||
)
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
@@ -36,4 +36,13 @@ services:
|
||||
interval: 10s # 每隔10秒检查一次
|
||||
timeout: 5s # 每次检查的超时时间为5秒
|
||||
retries: 3 # 重试3次失败后标记为 unhealthy
|
||||
start_period: 30s # 容器启动后等待30秒再开始第一次健康检查
|
||||
start_period: 30s # 容器启动后等待30秒再开始第一次健康检查
|
||||
# adminer:
|
||||
# image: adminer:latest
|
||||
# container_name: gemini-balance-adminer
|
||||
# restart: unless-stopped
|
||||
# ports:
|
||||
# - "8080:8080"
|
||||
# depends_on:
|
||||
# mysql:
|
||||
# condition: service_healthy
|
||||
Reference in New Issue
Block a user