Compare commits
142 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67f87989db | ||
|
|
17738b39a7 | ||
|
|
1e5312f96b | ||
|
|
548e69d87f | ||
|
|
90161a1f47 | ||
|
|
9ea3452b17 | ||
|
|
11e45fca37 | ||
|
|
c85fe979e5 | ||
|
|
a47edf1661 | ||
|
|
814a2e66c0 | ||
|
|
a7d548a849 | ||
|
|
b6a54190ed | ||
|
|
920228d3aa | ||
|
|
f1f568afca | ||
|
|
30bf666a57 | ||
|
|
c65d5244d6 | ||
|
|
4ad18e43ef | ||
|
|
f17cd66127 | ||
|
|
e1c068ed9e | ||
|
|
b86eac839d | ||
|
|
83252cbf33 | ||
|
|
12f6665519 | ||
|
|
1ff494416b | ||
|
|
8ec1d16e9d | ||
|
|
f13a4fba5f | ||
|
|
d4a3ed3a57 | ||
|
|
a6a1e7fb52 | ||
|
|
c01bc242aa | ||
|
|
ab06627d3f | ||
|
|
631d054d9e | ||
|
|
d835085e61 | ||
|
|
7c3ebe7e8b | ||
|
|
7e76d07e28 | ||
|
|
d21fb6c455 | ||
|
|
56f6f5e198 | ||
|
|
929592bbc4 | ||
|
|
2225a40bbe | ||
|
|
3480fa3b0f | ||
|
|
d7113f5fc4 | ||
|
|
2072f54ca1 | ||
|
|
7c9b721164 | ||
|
|
83ce50975a | ||
|
|
7da9110704 | ||
|
|
e9d19de7c6 | ||
|
|
e822831178 | ||
|
|
775930edce | ||
|
|
cb40848c04 | ||
|
|
7098c8755f | ||
|
|
705d602dee | ||
|
|
cd257a9406 | ||
|
|
cd54650431 | ||
|
|
a5602c602e | ||
|
|
dd70fd4c44 | ||
|
|
dbe50628b3 | ||
|
|
83ed0527d3 | ||
|
|
ab31f4bb98 | ||
|
|
734a8c4bc4 | ||
|
|
fea3af4692 | ||
|
|
9302cf295e | ||
|
|
b4f040e77a | ||
|
|
defabf4355 | ||
|
|
f3ed3168e4 | ||
|
|
01765b1731 | ||
|
|
f83f0fa768 | ||
|
|
a7085964e8 | ||
|
|
d3cd2856b7 | ||
|
|
353d22cc70 | ||
|
|
eb96474c19 | ||
|
|
0c48a2d74d | ||
|
|
1b23d574a5 | ||
|
|
ebc5dc571b | ||
|
|
9a7a1d7c2f | ||
|
|
c99e090ea9 | ||
|
|
eb311de0c2 | ||
|
|
c254077a66 | ||
|
|
ef4a528611 | ||
|
|
f593d97381 | ||
|
|
053ef631c4 | ||
|
|
075d20c62d | ||
|
|
0768aed179 | ||
|
|
c2eac24175 | ||
|
|
1c6dabcea7 | ||
|
|
76937aa24f | ||
|
|
b96ce8f15a | ||
|
|
87d60117c5 | ||
|
|
a53a30fd38 | ||
|
|
98e7fb62d5 | ||
|
|
6a59b4f847 | ||
|
|
d1ba2c4ae9 | ||
|
|
0693a5c245 | ||
|
|
742db744d1 | ||
|
|
12a84921c1 | ||
|
|
73e98a185d | ||
|
|
73a7c81f85 | ||
|
|
86dba93974 | ||
|
|
439165bc6c | ||
|
|
0dd9dd5380 | ||
|
|
aea2f39952 | ||
|
|
f7cfc8952f | ||
|
|
7b4652c802 | ||
|
|
51bb71bdb5 | ||
|
|
69261e98de | ||
|
|
f05d67939f | ||
|
|
d94d24f96c | ||
|
|
0f28173b0e | ||
|
|
af310ffb6b | ||
|
|
169488851f | ||
|
|
a7dc05a359 | ||
|
|
d0cc48ad63 | ||
|
|
5fc59a00d0 | ||
|
|
619f81cce4 | ||
|
|
a6c162b223 | ||
|
|
4c2f3ed9b0 | ||
|
|
ba38f14cd8 | ||
|
|
47bf47d90e | ||
|
|
cc36ba4c9e | ||
|
|
baf643e884 | ||
|
|
360bc9e48d | ||
|
|
c0a27d0542 | ||
|
|
84052a2179 | ||
|
|
2e7ecd88b5 | ||
|
|
0b1f3dfc04 | ||
|
|
c691c7c1cf | ||
|
|
97db7eebf1 | ||
|
|
60dca70fcd | ||
|
|
89b9f7919a | ||
|
|
a8dc98ab6a | ||
|
|
b3a057b6ba | ||
|
|
b14bb93d8f | ||
|
|
8ca62707ea | ||
|
|
21444ed6c7 | ||
|
|
ba292dbedd | ||
|
|
6ba58ce9d1 | ||
|
|
16f16a3ae9 | ||
|
|
26dcb64687 | ||
|
|
df88492113 | ||
|
|
851bb9c09b | ||
|
|
0cac178572 | ||
|
|
67c85c994a | ||
|
|
ee979dd568 | ||
|
|
e79a1ba56c | ||
|
|
016e6e06ee |
48
.env.example
@@ -1,13 +1,34 @@
|
||||
# 数据库配置
|
||||
DATABASE_TYPE=mysql
|
||||
#SQLITE_DATABASE=default_db
|
||||
MYSQL_HOST=gemini-balance-mysql
|
||||
#MYSQL_SOCKET=/run/mysqld/mysqld.sock
|
||||
MYSQL_PORT=3306
|
||||
MYSQL_USER=gemini
|
||||
MYSQL_PASSWORD=change_me
|
||||
MYSQL_DATABASE=default_db
|
||||
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
|
||||
ALLOWED_TOKENS=["sk-123456"]
|
||||
# AUTH_TOKEN=sk-123456
|
||||
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||
MODEL_IMAGE=["gemini-2.0-flash-exp"]
|
||||
AUTH_TOKEN=sk-123456
|
||||
TEST_MODEL=gemini-1.5-flash
|
||||
THINKING_MODELS=["gemini-2.5-flash-preview-04-17"]
|
||||
THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
|
||||
IMAGE_MODELS=["gemini-2.0-flash-exp"]
|
||||
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
SHOW_SEARCH_LINK=true
|
||||
SHOW_THINKING_PROCESS=true
|
||||
BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
MAX_FAILURES=10
|
||||
MAX_RETRIES=3
|
||||
CHECK_INTERVAL_HOURS=1
|
||||
TIMEZONE=Asia/Shanghai
|
||||
# 请求超时时间(秒)
|
||||
TIME_OUT=300
|
||||
# 代理服务器配置 (支持 http 和 socks5)
|
||||
# 示例: PROXIES=["http://user:pass@host:port", "socks5://host:port"]
|
||||
PROXIES=[]
|
||||
#########################image_generate 相关配置###########################
|
||||
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
|
||||
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
||||
@@ -18,9 +39,30 @@ CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
||||
##########################################################################
|
||||
#########################stream_optimizer 相关配置########################
|
||||
STREAM_OPTIMIZER_ENABLED=false
|
||||
STREAM_MIN_DELAY=0.016
|
||||
STREAM_MAX_DELAY=0.024
|
||||
STREAM_SHORT_TEXT_THRESHOLD=10
|
||||
STREAM_LONG_TEXT_THRESHOLD=50
|
||||
STREAM_CHUNK_SIZE=5
|
||||
##########################################################################
|
||||
######################### 日志配置 #######################################
|
||||
# 日志级别 (debug, info, warning, error, critical),默认为 info
|
||||
LOG_LEVEL=info
|
||||
# 是否开启自动删除错误日志
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED=true
|
||||
# 自动删除多少天前的错误日志 (1, 7, 30)
|
||||
AUTO_DELETE_ERROR_LOGS_DAYS=7
|
||||
# 是否开启自动删除请求日志
|
||||
AUTO_DELETE_REQUEST_LOGS_ENABLED=false
|
||||
# 自动删除多少天前的请求日志 (1, 7, 30)
|
||||
AUTO_DELETE_REQUEST_LOGS_DAYS=30
|
||||
##########################################################################
|
||||
|
||||
# 假流式配置 (Fake Streaming Configuration)
|
||||
FAKE_STREAM_ENABLED=True # 是否启用假流式输出
|
||||
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS=5 # 假流式发送空数据的间隔时间(秒)
|
||||
|
||||
# 安全设置 (JSON 字符串格式)
|
||||
# 注意:这里的示例值可能需要根据实际模型支持情况调整
|
||||
SAFETY_SETTINGS='[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]'
|
||||
|
||||
24
.github/workflows/docker-publish.yml
vendored
@@ -2,8 +2,6 @@ name: Docker Image CI
|
||||
|
||||
on:
|
||||
push:
|
||||
# branches: [ "main" ]
|
||||
tags: [ 'v*.*.*' ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
@@ -43,20 +41,30 @@ jobs:
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
# https://github.com/docker/metadata-action/tree/v5/?tab=readme-ov-file#semver
|
||||
# Event: push, Ref: refs/head/main, Tags: main
|
||||
# Event: push tag, Ref: refs/tags/v1.2.3, Tags: 1.2.3, 1.2, 1, latest
|
||||
# Event: push tag, Ref: refs/tags/v2.0.8-rc1, Tags: 2.0.8-rc1
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=sha,format=long
|
||||
type=semver,pattern={{major}}
|
||||
labels: |
|
||||
org.opencontainers.image.description=OpenAI API Compatible Server
|
||||
org.opencontainers.image.source=${{ github.event.repository.html_url }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
file: Dockerfile
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
load: false
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
cache-from: type=gha,scope=${{ github.workflow }}
|
||||
cache-to: type=gha,scope=${{ github.workflow }}
|
||||
|
||||
3
.gitignore
vendored
@@ -257,4 +257,5 @@ $RECYCLE.BIN/
|
||||
|
||||
# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
|
||||
|
||||
tests/
|
||||
tests/
|
||||
default_db
|
||||
@@ -4,6 +4,7 @@ WORKDIR /app
|
||||
|
||||
# 复制所需文件到容器中
|
||||
COPY ./requirements.txt /app
|
||||
COPY ./VERSION /app
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY ./app /app/app
|
||||
@@ -11,7 +12,8 @@ ENV API_KEYS='["your_api_key_1"]'
|
||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
|
||||
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
|
||||
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
17
LICENSE
Normal file
@@ -0,0 +1,17 @@
|
||||
知识共享署名-非商业性使用 4.0 国际 (CC BY-NC 4.0) 协议
|
||||
|
||||
您可以自由地:
|
||||
- 共享 — 在任何媒介以任何形式复制、发行本作品
|
||||
- 演绎 — 修改、转换或以本作品为基础进行创作
|
||||
|
||||
惟须遵守下列条件:
|
||||
- 署名 — 您必须给出适当的署名,提供指向本协议的链接,并指明是否(对原作)作了修改。您可以以任何合理方式进行,但不得以任何方式暗示许可方认可您或您的使用。
|
||||
- 非商业性使用 — 您不得将本作品用于商业目的,包括但不限于任何形式的商业倒卖、SaaS、API 付费接口、二次销售、打包出售、收费分发或其他直接或间接盈利行为。
|
||||
|
||||
如需商业授权,请联系原作者获得书面许可。违者将承担相应法律责任。
|
||||
|
||||
Creative Commons Attribution-NonCommercial 4.0 International Public License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
||||
|
||||
Full license text: https://creativecommons.org/licenses/by-nc/4.0/legalcode
|
||||
611
README.md
@@ -1,188 +1,81 @@
|
||||
# 🚀 FastAPI OpenAI (Gemini) 代理服务
|
||||
# Gemini Balance - Gemini API 代理和负载均衡器
|
||||
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
> ⚠️ 本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
|
||||
## 📝 项目简介
|
||||
> 本人从未在各个平台售卖服务,如有遇到售卖此服务者,那一定是倒卖狗,大家切记不要上当受骗。
|
||||
|
||||
本项目是一个基于 FastAPI 框架开发的高性能、易于部署的Gemini OpenAI兼容 和 Gemini API 代理服务。它不仅兼容 OpenAI 的 API 接口,还支持 Google 的 Gemini 原生接口。该代理服务内置了多 API Key 轮询、负载均衡、自动重试、访问控制(Bearer Token 认证)、流式响应等功能,旨在简化 AI 应用的开发和部署流程。
|
||||
[](https://www.python.org/)
|
||||
[](https://fastapi.tiangolo.com/)
|
||||
[](https://www.uvicorn.org/)
|
||||
[](https://t.me/+soaHax5lyI0wZDVl)
|
||||
> 交流群:https://t.me/+soaHax5lyI0wZDVl
|
||||
|
||||
**核心功能与优势:**
|
||||
## 项目简介
|
||||
|
||||
- **多协议支持**: 无缝切换 OpenAI兼容 和 Gemini 协议。
|
||||
- **智能 API Key 管理**: 自动轮询多个 API Key,实现负载均衡和故障转移。
|
||||
- **安全访问控制**: 使用 Bearer Token 进行身份验证,保护 API 访问。
|
||||
- **流式响应支持**: 提供实时的流式数据传输,提升用户体验。
|
||||
- **内置工具支持**: 支持代码执行和 Google 搜索等工具, 丰富模型功能 (可选)。
|
||||
- **灵活配置**: 通过环境变量或 `.env` 文件轻松配置。
|
||||
- **易于部署**: 提供 Docker 一键部署,也支持手动部署。
|
||||
- **健康检查**: 提供健康检查接口,方便监控服务状态。
|
||||
- **图片生成支持**: 支持使用OpenAI的DALL-E模型生成图片
|
||||
Gemini Balance 是一个基于 Python FastAPI 构建的应用程序,旨在提供 Google Gemini API 的代理和负载均衡功能。它允许您管理多个 Gemini API Key,并通过简单的配置实现 Key 的轮询、认证、模型过滤和状态监控。此外,项目还集成了图像生成和多种图床上传功能,并支持 OpenAI API 格式的代理。
|
||||
|
||||
## 🛠️ 技术栈
|
||||
**项目结构:**
|
||||
|
||||
- **FastAPI**: 高性能 Web 框架。
|
||||
- **Python 3.9+**: 编程语言。
|
||||
- **Pydantic**: 数据验证和设置管理。
|
||||
- **httpx**: 异步 HTTP 客户端。
|
||||
- **uvicorn**: ASGI 服务器。
|
||||
- **Docker**: 容器化部署 (可选)。
|
||||
```plaintext
|
||||
app/
|
||||
├── config/ # 配置管理
|
||||
├── core/ # 核心应用逻辑 (FastAPI 实例创建, 中间件等)
|
||||
├── database/ # 数据库模型和连接
|
||||
├── domain/ # 业务领域对象 (可选)
|
||||
├── exception/ # 自定义异常
|
||||
├── handler/ # 请求处理器 (可选, 或在 router 中处理)
|
||||
├── log/ # 日志配置
|
||||
├── main.py # 应用入口
|
||||
├── middleware/ # FastAPI 中间件
|
||||
├── router/ # API 路由 (Gemini, OpenAI, 状态页等)
|
||||
├── scheduler/ # 定时任务 (如 Key 状态检查)
|
||||
├── service/ # 业务逻辑服务 (聊天, Key 管理, 统计等)
|
||||
├── static/ # 静态文件 (CSS, JS)
|
||||
├── templates/ # HTML 模板 (如 Key 状态页)
|
||||
├── utils/ # 工具函数
|
||||
```
|
||||
|
||||
## ✨ 功能亮点
|
||||
|
||||
* **多 Key 负载均衡**: 支持配置多个 Gemini API Key (`API_KEYS`),自动按顺序轮询使用,提高可用性和并发能力。
|
||||
* **可视化配置即时生效**: 通过管理后台修改配置后,无需重启服务即可生效,切记要点击保存才会生效。
|
||||

|
||||
* **双协议API 兼容**: 同时支持 Gemini 和 OpenAI 格式的 CHAT API 请求转发。
|
||||
|
||||
```palintext
|
||||
openai baseurl `http://localhost:8000(/hf)/v1`
|
||||
gemini baseurl `http://localhost:8000(/gemini)/v1beta`
|
||||
```
|
||||
|
||||
* **支持图文对话和修改图片**: `IMAGE_MODELS`配置哪个模型可以图文对话和修图的功能,实际调用的时候,用 `配置模型-image`这个模型名对话使用该功能。
|
||||

|
||||

|
||||
* **支持联网搜索**: 支持联网搜索,`SEARCH_MODELS`配置哪些模型可以联网搜索,实际调用的时候,用 `配置模型-search`这个模型名对话使用该功能
|
||||

|
||||
* **Key 状态监控**: 提供 `/keys_status` 页面(需要认证),实时查看各 Key 的状态和使用情况。
|
||||

|
||||
* **详细的日志记录**: 提供详细的错误日志,方便排查。
|
||||

|
||||

|
||||

|
||||
* **支持自定义gemini代理**: 支持自定义gemini代理,比如自行在deno或者cloudflare上搭建gemini代理
|
||||
* **openai画图接口兼容**: 将`imagen-3.0-generate-002`模型接口改造成openai画图接口,支持客户端调用。
|
||||
* **灵活的添加密钥方式**: 灵活的添加密钥方式,采用正则匹配`gemini_key`,密钥去重
|
||||

|
||||
* **兼容openai格式embeddings接口**:完美适配openai格式的`embeddings`接口,可用于本地文档向量化。
|
||||
* **流式响应优化**: 可选的流式输出优化器 (`STREAM_OPTIMIZER_ENABLED`),改善长文本流式响应的体验。
|
||||
* **失败重试与 Key 管理**: 自动处理 API 请求失败,进行重试 (`MAX_RETRIES`),并在 Key 失效次数过多时自动禁用 (`MAX_FAILURES`),定时检查恢复 (`CHECK_INTERVAL_HOURS`)。
|
||||
* **Docker 支持**: 支持AMD,ARM架构的docker部署,也可自行构建docker镜像。
|
||||
>镜像地址: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **模型列表自动维护**: 支持openai和gemini模型列表获取,与newapi自动获取模型列表完美兼容,无需手动填写。
|
||||
* **支持移除不使用的模型**: 默认提供的模型太多,很多用不上,可以通过`FILTERED_MODELS`过滤掉。
|
||||
* **代理支持**: 支持配置 HTTP/SOCKS5 代理服务器 (`PROXIES`),用于访问 Gemini API,方便在特殊网络环境下使用。支持批量添加代理。
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 环境要求
|
||||
### 自行构建 Docker (推荐)
|
||||
|
||||
- Python 3.9 或更高版本
|
||||
- Docker (可选,推荐用于生产环境)
|
||||
|
||||
### 📦 安装与配置
|
||||
|
||||
1. **克隆项目**:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/snailyp/gemini-balance.git
|
||||
cd gemini-balance
|
||||
```
|
||||
|
||||
2. **安装依赖**:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. **配置**:
|
||||
|
||||
创建 `.env` 文件,并按以下分类配置环境变量:
|
||||
|
||||
```env
|
||||
# 基础配置
|
||||
BASE_URL="https://generativelanguage.googleapis.com/v1beta" # Gemini API 基础 URL,默认无需修改
|
||||
MAX_FAILURES=3 # 允许单个key失败的次数,默认3次
|
||||
|
||||
# 认证与安全配置
|
||||
API_KEYS=["your-gemini-api-key-1", "your-gemini-api-key-2"] # Gemini API 密钥列表,用于负载均衡
|
||||
ALLOWED_TOKENS=["your-access-token-1", "your-access-token-2"] # 允许访问的 Token 列表
|
||||
AUTH_TOKEN="" # 超级管理员token,具有所有权限,默认使用 ALLOWED_TOKENS 的第一个
|
||||
|
||||
# 模型功能配置
|
||||
MODEL_SEARCH=["gemini-2.0-flash-exp"] # 支持搜索功能的模型列表
|
||||
TOOLS_CODE_EXECUTION_ENABLED=false # 是否启用代码执行工具,默认false
|
||||
SHOW_SEARCH_LINK=true # 是否在响应中显示搜索结果链接,默认true
|
||||
SHOW_THINKING_PROCESS=true # 是否显示模型思考过程,默认true
|
||||
|
||||
# 图片生成配置
|
||||
PAID_KEY="your-paid-api-key" # 付费版API Key,用于图片生成等高级功能
|
||||
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型,默认使用imagen-3.0
|
||||
|
||||
# 图片上传配置
|
||||
UPLOAD_PROVIDER="smms" # 图片上传提供商,目前支持smms、picgo、cloudflare_imgbed
|
||||
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
|
||||
PICGO_API_KEY="your-picogo-apikey" # PicoGo图床的API Key 可在 `https://www.picgo.net/settings/api` 获取
|
||||
CLOUDFLARE_IMGBED_URL="https://xxxxxxx.pages.dev/upload" # CloudFlare 图床上传地址,可自行搭建:`https://github.com/MarSeventh/CloudFlare-ImgBed`
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE="your-cloudflare-imgber-auth-code" # CloudFlare图床的鉴权key,可在项目后台设置,若无鉴权则可直接置空。
|
||||
|
||||
# stream_optimizer 相关配置
|
||||
STREAM_MIN_DELAY=0.016
|
||||
STREAM_MAX_DELAY=0.024
|
||||
STREAM_SHORT_TEXT_THRESHOLD=10
|
||||
STREAM_LONG_TEXT_THRESHOLD=50
|
||||
STREAM_CHUNK_SIZE=5
|
||||
```
|
||||
|
||||
### 配置说明
|
||||
|
||||
#### 基础配置
|
||||
|
||||
- `BASE_URL`: Gemini API 的基础 URL
|
||||
- 默认值: `https://generativelanguage.googleapis.com/v1beta`
|
||||
- 说明: 通常无需修改,除非 API 地址发生变化
|
||||
- `MAX_FAILURES`: API Key 允许的最大失败次数
|
||||
- 默认值: `3`
|
||||
- 说明: 超过此次数后,Key 将被暂时标记为无效
|
||||
|
||||
#### 认证与安全配置
|
||||
|
||||
- `API_KEYS`: Gemini API 密钥列表
|
||||
- 格式: JSON 数组字符串
|
||||
- 用途: 支持多个 Key 轮询,实现负载均衡
|
||||
- 建议: 至少配置 2 个 Key 以保证服务可用性
|
||||
- `ALLOWED_TOKENS`: 访问令牌列表
|
||||
- 格式: JSON 数组字符串
|
||||
- 用途: 用于客户端认证
|
||||
- 安全提示: 请使用足够复杂的令牌
|
||||
- `AUTH_TOKEN`: 超级管理员令牌
|
||||
- 可选配置,留空则使用 ALLOWED_TOKENS 的第一个
|
||||
- 具有查看 API Key 状态等特权操作权限
|
||||
|
||||
#### 模型功能配置
|
||||
|
||||
- `MODEL_SEARCH`: 搜索功能支持的模型
|
||||
- 默认值: `["gemini-2.0-flash-exp"]`
|
||||
- 说明: 仅列表中的模型可使用搜索功能
|
||||
- `TOOLS_CODE_EXECUTION_ENABLED`: 代码执行功能
|
||||
- 默认值: `false`
|
||||
- 安全提示: 生产环境建议禁用
|
||||
- `SHOW_SEARCH_LINK`: 搜索结果链接显示
|
||||
- 默认值: `true`
|
||||
- 用途: 控制搜索结果中是否包含原始链接
|
||||
- `SHOW_THINKING_PROCESS`: 思考过程显示
|
||||
- 默认值: `true`
|
||||
- 用途: 显示模型的推理过程,便于调试
|
||||
|
||||
#### 图片生成配置
|
||||
|
||||
- `PAID_KEY`: 付费版 API Key
|
||||
- 用途: 用于图片生成等高级功能
|
||||
- 说明: 需要单独申请的付费版 Key
|
||||
- `CREATE_IMAGE_MODEL`: 图片生成模型
|
||||
- 默认值: `imagen-3.0-generate-002`
|
||||
- 说明: 当前支持的最新图片生成模型
|
||||
|
||||
#### 图片上传配置
|
||||
|
||||
- `UPLOAD_PROVIDER`: 图片上传服务提供商
|
||||
- 默认值: `smms`
|
||||
- 可选值: `smms`, `picgo`, `cloudflare_imgbed`
|
||||
- 说明: 用于选择图片上传的服务提供商。目前支持 SM.MS 图床, PicGo 图床, 以及 Cloudflare ImgBed。
|
||||
|
||||
- `SMMS_SECRET_TOKEN`: SM.MS API Token
|
||||
- 用途: 用于图片上传到 SM.MS 图床的身份验证。
|
||||
- 获取方式: 需要在 [SM.MS 官网](https://sm.ms/) 注册并获取。
|
||||
|
||||
- `PICGO_API_KEY`: PicGo API Key
|
||||
- 用途: 用于图片上传到 PicGo 图床的身份验证。
|
||||
- 获取方式: 可在 [PicGo 官网](https://www.picgo.net/settings/api) 的设置页面 API 选项中获取。
|
||||
|
||||
- `CLOUDFLARE_IMGBED_URL`: Cloudflare ImgBed 上传地址
|
||||
- 用途: 指定 Cloudflare ImgBed 图床的上传 API 地址。
|
||||
- 获取方式: 如果您自行搭建了 Cloudflare ImgBed 服务,请填写您的服务部署地址。参考 [Cloudflare-ImgBed 项目](https://github.com/MarSeventh/CloudFlare-ImgBed) 自行搭建。
|
||||
- 注意: URL 必须以 `https://` 开头,并指向 `/upload` 路径 ,例如 `https://cloudflare-imgbed-7b0.pages.dev/upload`。
|
||||
|
||||
- `CLOUDFLARE_IMGBED_AUTH_CODE`: Cloudflare ImgBed 鉴权 Key
|
||||
- 用途: 用于 Cloudflare ImgBed 图床的身份验证。
|
||||
- 说明: 如果您的 Cloudflare ImgBed 服务启用了鉴权,请填写鉴权 Key。若未启用鉴权,则留空即可。
|
||||
- 获取方式: 在 Cloudflare ImgBed 项目的后台设置中获取,或在搭建时自行设置。
|
||||
|
||||
#### 流式输出优化配置
|
||||
|
||||
- `STREAM_MIN_DELAY`: 最小延迟时间
|
||||
- 默认值: `0.016`(秒)
|
||||
- 说明: 长文本输出时使用的最小延迟时间,值越小输出速度越快
|
||||
- `STREAM_MAX_DELAY`: 最大延迟时间
|
||||
- 默认值: `0.024`(秒)
|
||||
- 说明: 短文本输出时使用的最大延迟时间,值越大输出速度越慢
|
||||
- `STREAM_SHORT_TEXT_THRESHOLD`: 短文本阈值
|
||||
- 默认值: `10`(字符)
|
||||
- 说明: 小于此长度的文本被视为短文本,将使用最大延迟输出
|
||||
- `STREAM_LONG_TEXT_THRESHOLD`: 长文本阈值
|
||||
- 默认值: `50`(字符)
|
||||
- 说明: 大于此长度的文本被视为长文本,将使用最小延迟并分块输出
|
||||
- `STREAM_CHUNK_SIZE`: 长文本分块大小
|
||||
- 默认值: `5`(字符)
|
||||
- 说明: 长文本分块输出时,每个块的大小
|
||||
|
||||
### ▶️ 运行
|
||||
|
||||
#### 使用 Docker (推荐)
|
||||
#### a) dockerfile构建
|
||||
|
||||
1. **构建镜像**:
|
||||
|
||||
@@ -196,282 +89,178 @@
|
||||
docker run -d -p 8000:8000 --env-file .env gemini-balance
|
||||
```
|
||||
|
||||
- `-d`: 后台运行。
|
||||
- `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口。
|
||||
- `--env-file .env`: 使用 `.env` 文件设置环境变量。
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量。
|
||||
|
||||
#### 手动运行
|
||||
> 注意:如果使用 SQLite 数据库,需要挂载数据卷以持久化数据:
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data gemini-balance
|
||||
> ```
|
||||
> 其中 `/path/to/data` 是主机上的数据存储路径,`/app/data` 是容器内的数据目录。
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
#### b) 用现有的docker镜像部署
|
||||
|
||||
- `--reload`: 开启热重载,方便开发调试 (生产环境不建议开启)。
|
||||
1. **拉取镜像**:
|
||||
|
||||
## 🔌 API 接口
|
||||
```bash
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
### 认证
|
||||
2. **运行容器**:
|
||||
|
||||
所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer <your-token>`,其中 `<your-token>` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个或者 `AUTH_TOKEN`。
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
### API 路由
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口 (根据需要调整)。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量 (确保 `.env` 文件存在于执行命令的目录)。
|
||||
|
||||
本服务提供两种API路由:
|
||||
> 注意:如果使用 SQLite 数据库,需要挂载数据卷以持久化数据:
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data ghcr.io/snailyp/gemini-balance:latest
|
||||
> ```
|
||||
> 其中 `/path/to/data` 是主机上的数据存储路径,`/app/data` 是容器内的数据目录。
|
||||
|
||||
1. **OpenAI 兼容路由** (推荐)
|
||||
- 基础路径: `/v1`
|
||||
- 完全兼容OpenAI API格式
|
||||
- 支持所有Gemini模型
|
||||
### 本地运行 (适用于开发和测试)
|
||||
|
||||
2. **Gemini 原生路由**
|
||||
- 基础路径: `/gemini/v1beta` 或 `/v1beta`
|
||||
- 遵循Google原生API格式
|
||||
- 适用于需要直接使用Gemini API的场景
|
||||
如果您想在本地直接运行源代码进行开发或测试,请按照以下步骤操作:
|
||||
|
||||
### OpenAI兼容路由
|
||||
1. **确保已完成准备工作**:
|
||||
* 克隆仓库到本地。
|
||||
* 安装 Python 3.9 或更高版本。
|
||||
* 在项目根目录下创建并配置好 `.env` 文件 (参考前面的"配置环境变量"部分)。
|
||||
* 安装项目依赖:
|
||||
|
||||
#### 获取模型列表
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
- **URL**: `/v1/models`
|
||||
- **Method**: `GET`
|
||||
- **Header**: `Authorization: Bearer <your-token>`
|
||||
- **Response**: 返回支持的所有模型列表,包括最新的`gemini-2.0-flash-exp-search`等模型
|
||||
2. **启动应用**:
|
||||
在项目根目录下运行以下命令:
|
||||
|
||||
#### 聊天补全 (Chat Completions)
|
||||
|
||||
- **URL**: `/v1/chat/completions`
|
||||
- **Method**: `POST`
|
||||
- **Header**: `Authorization: Bearer <your-token>`
|
||||
- **Body** (JSON):
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
}
|
||||
],
|
||||
"model": "gemini-1.5-flash-002",
|
||||
"temperature": 0.7,
|
||||
"stream": false,
|
||||
"tools": [],
|
||||
"max_tokens": 8192,
|
||||
"stop": [],
|
||||
"top_p": 0.9,
|
||||
"top_k": 40
|
||||
}
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
- `messages`: 消息列表,格式与 OpenAI API 相同
|
||||
- `model`: 模型名称,支持所有Gemini模型,包括:
|
||||
- `gemini-1.5-flash-002`: 快速响应模型
|
||||
- `gemini-2.0-flash-exp`: 实验性快速响应模型
|
||||
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
|
||||
- `stream`: 是否开启流式响应,`true` 或 `false`
|
||||
- `tools`: 使用的工具列表
|
||||
- 其他参数:与 OpenAI API 兼容的参数,如 `temperature`, `max_tokens` 等
|
||||
* `app.main:app`: 指定 FastAPI 应用实例的位置 (`app` 模块中的 `main.py` 文件里的 `app` 对象)。
|
||||
* `--host 0.0.0.0`: 使应用可以从本地网络中的任何 IP 地址访问。
|
||||
* `--port 8000`: 指定应用监听的端口号 (您可以根据需要修改)。
|
||||
* `--reload`: 启用自动重载功能。当您修改代码时,服务会自动重启,非常适合开发环境 (生产环境请移除此选项)。
|
||||
|
||||
### Gemini原生路由
|
||||
3. **访问应用**:
|
||||
应用启动后,您可以通过浏览器或 API 工具访问 `http://localhost:8000` (或您指定的主机和端口)。
|
||||
|
||||
#### 获取模型列表
|
||||
### 完整配置项列表
|
||||
|
||||
- **URL**: `/gemini/v1beta/models` 或 `/v1beta/models`
|
||||
- **Method**: `GET`
|
||||
- **Header**: `Authorization: Bearer <your-token>`
|
||||
| 配置项 | 说明 | 默认值 |
|
||||
| :--------------------------- | :------------------------------------------------------- | :---------------------------------------------------- |
|
||||
| **数据库配置** | | |
|
||||
| `DATABASE_TYPE` | 可选,数据库类型,支持 `mysql` 或 `sqlite` | `mysql` |
|
||||
| `SQLITE_DATABASE` | 可选,当使用 `sqlite` 时必填,SQLite 数据库文件路径 | `default_db` |
|
||||
| `MYSQL_HOST` | 当使用 `mysql` 时必填,MySQL 数据库主机地址 | `localhost` |
|
||||
| `MYSQL_SOCKET` | 可选,MySQL 数据库 socket 地址 | `/var/run/mysqld/mysqld.sock` |
|
||||
| `MYSQL_PORT` | 当使用 `mysql` 时必填,MySQL 数据库端口 | `3306` |
|
||||
| `MYSQL_USER` | 当使用 `mysql` 时必填,MySQL 数据库用户名 | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | 当使用 `mysql` 时必填,MySQL 数据库密码 | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | 当使用 `mysql` 时必填,MySQL 数据库名称 | `defaultdb` |
|
||||
| **API 相关配置** | | |
|
||||
| `API_KEYS` | 必填,Gemini API 密钥列表,用于负载均衡 | `["your-gemini-api-key-1", "your-gemini-api-key-2"]` |
|
||||
| `ALLOWED_TOKENS` | 必填,允许访问的 Token 列表 | `["your-access-token-1", "your-access-token-2"]` |
|
||||
| `AUTH_TOKEN` | 可选,超级管理员token,具有所有权限,不填默认使用 ALLOWED_TOKENS 的第一个 | `sk-123456` |
|
||||
| `TEST_MODEL` | 可选,用于测试密钥是否可用的模型名 | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | 可选,支持绘图功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | 可选,支持搜索功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `FILTERED_MODELS` | 可选,被禁用的模型列表 | `["gemini-1.0-pro-vision-latest", ...]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | 可选,是否启用代码执行工具 | `false` |
|
||||
| `SHOW_SEARCH_LINK` | 可选,是否在响应中显示搜索结果链接 | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | 可选,是否显示模型思考过程 | `true` |
|
||||
| `THINKING_MODELS` | 可选,支持思考功能的模型列表 | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | 可选,思考功能预算映射 (模型名:预算值) | `{}` |
|
||||
| `BASE_URL` | 可选,Gemini API 基础 URL,默认无需修改 | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | 可选,允许单个key失败的次数 | `3` |
|
||||
| `MAX_RETRIES` | 可选,API 请求失败时的最大重试次数 | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | 可选,检查禁用 Key 是否恢复的时间间隔 (小时) | `1` |
|
||||
| `TIMEZONE` | 可选,应用程序使用的时区 | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | 可选,请求超时时间 (秒) | `300` |
|
||||
| `PROXIES` | 可选,代理服务器列表 (例如 `http://user:pass@host:port`, `socks5://host:port`) | `[]` |
|
||||
| `LOG_LEVEL` | 可选,日志级别,例如 DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | 可选,是否开启自动删除错误日志 | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | 可选,自动删除多少天前的错误日志 (例如 1, 7, 30) | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 可选,是否开启自动删除请求日志 | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | 可选,自动删除多少天前的请求日志 (例如 1, 7, 30) | `30` |
|
||||
| `SAFETY_SETTINGS` | 可选,安全设置 (JSON 字符串格式),用于配置内容安全阈值。示例值可能需要根据实际模型支持情况调整。 | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 可选,付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | 可选,图片生成模型 | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | 可选,图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `SMMS_SECRET_TOKEN` | 可选,SM.MS图床的API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | 可选,[PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | 可选,[CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| 可选,CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
||||
| **流式优化器相关** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | 可选,是否启用流式输出优化 | `false` |
|
||||
| `STREAM_MIN_DELAY` | 可选,流式输出最小延迟 | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | 可选,流式输出最大延迟 | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD`| 可选,短文本阈值 | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | 可选,长文本阈值 | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | 可选,流式输出块大小 | `5` |
|
||||
| **伪流式 (Fake Stream) 相关** | | |
|
||||
| `FAKE_STREAM_ENABLED` | 可选,是否启用伪流式传输,用于不支持流式的模型或场景 | `false` |
|
||||
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | 可选,伪流式传输时发送心跳空数据的间隔秒数 | `5` |
|
||||
|
||||
#### 生成内容
|
||||
## ⚙️ API 端点
|
||||
|
||||
- **URL**: `/gemini/v1beta/models/{model_name}:generateContent`
|
||||
- **Method**: `POST`
|
||||
- **Header**: `Authorization: Bearer <your-token>`
|
||||
以下是服务提供的主要 API 端点:
|
||||
|
||||
#### 流式生成内容
|
||||
### Gemini API 相关 (`(/gemini)/v1beta`)
|
||||
|
||||
- **URL**: `/gemini/v1beta/models/{model_name}:streamGenerateContent`
|
||||
- **Method**: `POST`
|
||||
- **Header**: `Authorization: Bearer <your-token>`
|
||||
* `GET /models`: 列出可用的 Gemini 模型。
|
||||
* `POST /models/{model_name}:generateContent`: 使用指定的 Gemini 模型生成内容。
|
||||
* `POST /models/{model_name}:streamGenerateContent`: 使用指定的 Gemini 模型流式生成内容。
|
||||
|
||||
### 获取词向量 (Embeddings)
|
||||
### OpenAI API 相关
|
||||
|
||||
- **URL**: `/v1/embeddings`
|
||||
- **Method**: `POST`
|
||||
- **Header**: `Authorization: Bearer <your-token>`
|
||||
- **Body** (JSON):
|
||||
|
||||
```json
|
||||
{
|
||||
"input": "你的文本",
|
||||
"model": "text-embedding-004"
|
||||
}
|
||||
```
|
||||
|
||||
- `input`: 输入文本。
|
||||
- `model`: 模型名称。
|
||||
|
||||
### 健康检查
|
||||
|
||||
- **URL**: `/health`
|
||||
- **Method**: `GET`
|
||||
|
||||
### Web界面功能
|
||||
|
||||
#### 验证页面 (auth.html)
|
||||
|
||||
- **URL**: `/auth`
|
||||
- **说明**: 提供了一个简洁的Web界面用于验证访问令牌
|
||||
- **功能特点**:
|
||||
- 现代化的渐变背景设计
|
||||
- 响应式布局,完美支持移动端
|
||||
- 毛玻璃效果的卡片设计
|
||||
- 优雅的动画效果(淡入、滑动、悬浮)
|
||||
- 安全的令牌验证机制
|
||||
- 清晰的错误提示功能
|
||||
- PWA支持,可安装为本地应用
|
||||
- 底部版权信息和GitHub链接
|
||||
- 支持暗色主题适配
|
||||
|
||||
#### API密钥状态管理 (keys_status.html)
|
||||
|
||||
- **URL**: `/v1/keys/list`
|
||||
- **Method**: `GET`
|
||||
- **Header**: `Authorization: Bearer <your-auth-token>`
|
||||
- **功能特点**:
|
||||
- 只有使用 `AUTH_TOKEN` 才能访问此接口
|
||||
- 分类展示API密钥状态(有效/无效)
|
||||
- 可折叠的密钥列表分组
|
||||
- 每个密钥显示:
|
||||
- 状态标识(有效/无效)
|
||||
- 密钥内容
|
||||
- 失败次数统计
|
||||
- 高级功能:
|
||||
- 一键复制单个密钥
|
||||
- 批量复制分组密钥(JSON格式)
|
||||
- 实时刷新功能
|
||||
- 回到顶部/底部快捷按钮
|
||||
- 界面特性:
|
||||
- 响应式设计,适配各种屏幕
|
||||
- 优雅的动画效果
|
||||
- 操作反馈(复制成功提示)
|
||||
- PWA支持
|
||||
- 暗色主题适配
|
||||
|
||||
### 图片生成 (Image Generation)
|
||||
|
||||
- **URL**: `/v1/images/generations`
|
||||
- **Method**: `POST`
|
||||
- **Header**: `Authorization: Bearer <your-auth-token>`
|
||||
- **说明**: Body示例和参数说明
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "dall-e-3",
|
||||
"prompt": "{n:2} {ratio:16:9} 汉服美女",
|
||||
"n": 1,
|
||||
"size": "1024x1024"
|
||||
}
|
||||
```
|
||||
|
||||
**Prompt参数说明:**
|
||||
|
||||
prompt支持通过特殊标记来控制生成参数:
|
||||
|
||||
1. 图片数量控制:
|
||||
- 格式: `{n:数量}`
|
||||
- 示例: `{n:2} 一只可爱的猫` - 生成2张图片
|
||||
- 取值范围: 1-4
|
||||
- 说明: 如果在prompt中指定了n,将覆盖请求body中的n参数
|
||||
|
||||
2. 图片比例控制:
|
||||
- 格式: `{ratio:宽:高}`
|
||||
- 示例: `{ratio:16:9} 一片森林` - 生成16:9比例的图片
|
||||
- 支持的比例: "1:1"、"3:4"、"4:3"、"9:16"、"16:9"
|
||||
- 说明: 如果指定了size参数,将优先使用size对应的比例
|
||||
|
||||
3. 参数组合:
|
||||
- 示例: `{n:2} {ratio:16:9} 一片美丽的森林` - 生成2张16:9比例的图片
|
||||
- 说明: 这些参数标记会自动从prompt中移除,不会影响实际的图片生成提示词
|
||||
|
||||
> 注意:n的取值范围[1,4], ratio取值范围"1:1"、"3:4"、"4:3"、"9:16" 和 "16:9"
|
||||
|
||||
## 📚 代码结构
|
||||
|
||||
```plaintext
|
||||
.
|
||||
├── app/
|
||||
│ ├── api/ # API 路由
|
||||
│ │ ├── gemini_routes.py # Gemini 模型路由
|
||||
│ │ └── openai_routes.py # OpenAI 兼容路由
|
||||
│ ├── core/ # 核心组件
|
||||
│ │ ├── config.py # 配置管理
|
||||
│ │ ├── logger.py # 日志配置
|
||||
│ │ └── security.py # 安全认证
|
||||
│ ├── middleware/ # 中间件
|
||||
│ │ └── request_logging_middleware.py # 请求日志中间件
|
||||
│ ├── schemas/ # 数据模型
|
||||
│ │ ├── gemini_models.py # Gemini 原始请求/响应模型
|
||||
│ │ └── openai_models.py # OpenAI 兼容请求/响应模型
|
||||
│ ├── services/ # 服务层
|
||||
│ │ ├── chat/ # 聊天相关服务
|
||||
│ │ │ ├── api_client.py # API 客户端
|
||||
│ │ │ ├── message_converter.py # 消息转换器
|
||||
│ │ │ ├── response_handler.py # 响应处理器
|
||||
│ │ │ └── retry_handler.py #重试处理器
|
||||
│ │ ├── gemini_chat_service.py # Gemini 原始聊天服务
|
||||
│ │ ├── openai_chat_service.py # OpenAI 兼容聊天服务
|
||||
│ │ ├── embedding_service.py # 向量服务
|
||||
│ │ ├── key_manager.py # API Key 管理
|
||||
│ │ └── model_service.py # 模型服务
|
||||
│ └── main.py # 主程序入口
|
||||
├── Dockerfile # Dockerfile
|
||||
├── requirements.txt # 项目依赖
|
||||
└── README.md # 项目说明
|
||||
```
|
||||
|
||||
## 🔒 安全性
|
||||
|
||||
- **API Key 轮询**: 自动轮换 API Key,提高可用性和负载均衡。
|
||||
- **Bearer Token 认证**: 保护 API 端点,防止未经授权的访问。
|
||||
- **请求日志记录**: 记录详细的请求信息,便于调试和审计 (可选,通过取消 `app.add_middleware(RequestLoggingMiddleware)` 的注释来启用)。
|
||||
- **自动重试**: 在 API 请求失败时自动重试,提高服务的稳定性。
|
||||
* `GET (/hf)/v1/models`: 列出可用的模型 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/chat/completions`: 进行聊天补全 (底层用的gemini格式, 支持流式传输)。
|
||||
* `POST (/hf)/v1/embeddings`: 创建文本嵌入 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/images/generations`: 生成图像 (底层用的gemini格式)。
|
||||
* `GET /openai/v1/models`: 列出可用的模型 (底层用的openai格式)。
|
||||
* `POST /openai/v1/chat/completions`: 进行聊天补全 (底层用的openai格式, 支持流式传输, 可防止截断,速度也快)。
|
||||
* `POST /openai/v1/embeddings`: 创建文本嵌入 (底层用的openai格式)。
|
||||
* `POST /openai/v1/images/generations`: 生成图像 (底层用的openai格式)。
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
欢迎任何形式的贡献!如果你发现 bug、有新功能建议或者想改进代码,请随时提交 Issue 或 Pull Request。
|
||||
欢迎提交 Pull Request 或 Issue。
|
||||
|
||||
1. Fork 本项目。
|
||||
2. 创建你的特性分支 (`git checkout -b feature/AmazingFeature`)。
|
||||
3. 提交你的改动 (`git commit -m 'Add some AmazingFeature'`)。
|
||||
4. 推送到你的分支 (`git push origin feature/AmazingFeature`)。
|
||||
5. 创建一个新的 Pull Request。
|
||||
## 🎉 特别鸣谢
|
||||
|
||||
## ❓ 常见问题解答 (FAQ)
|
||||
特别鸣谢以下项目和平台为本项目提供图床服务:
|
||||
|
||||
**Q: 如何获取 Gemini API Key?**
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) 开源项目
|
||||
|
||||
A: 请参考 Gemini API 的官方文档,申请 API Key。
|
||||
## 🙏 感谢贡献者
|
||||
|
||||
**Q: 如何配置多个 API Key?**
|
||||
感谢所有为本项目做出贡献的开发者!
|
||||
|
||||
A: 在 `.env` 文件的 `API_KEYS` 变量中,用列表的形式添加多个 Key,例如:`API_KEYS=["key1", "key2", "key3"]`。
|
||||
[](https://github.com/snailyp/gemini-balance/graphs/contributors)
|
||||
|
||||
**Q: 为什么我的 API Key 总是失败?**
|
||||
## ⭐ Star History
|
||||
|
||||
A: 请检查以下几点:
|
||||
[](https://star-history.com/#snailyp/gemini-balance&Date)
|
||||
|
||||
- API Key 是否正确。
|
||||
- API Key 是否已过期或被禁用。
|
||||
- 是否超出了 API Key 的速率限制或配额。
|
||||
- 网络连接是否正常。
|
||||
## 💖 友情项目
|
||||
|
||||
**Q: 如何启用流式响应?**
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine一线:AI驱动的热点事件时间轴生成工具
|
||||
|
||||
A: 在请求的 Body 中,将 `stream` 参数设置为 `true` 即可。
|
||||
## 🎁 项目支持
|
||||
|
||||
**Q: 如何启用代码执行工具?**
|
||||
如果你觉得这个项目对你有帮助,可以考虑通过 [爱发电](https://afdian.com/a/snaily) 支持我。
|
||||
|
||||
A: 在 `.env` 文件的 `TOOLS_CODE_EXECUTION_ENABLED` 变量中, 设置为 `true` 即可。
|
||||
## 许可证
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
本项目采用 MIT 许可证。有关详细信息,请参阅 [LICENSE](LICENSE) 文件 (你需要创建一个 LICENSE 文件)。
|
||||
本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.schemas.gemini_models import GeminiContent, GeminiRequest
|
||||
from app.services.gemini_chat_service import GeminiChatService
|
||||
from app.services.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.services.model_service import ModelService
|
||||
from app.services.chat.retry_handler import RetryHandler
|
||||
|
||||
router = APIRouter(prefix="/gemini/v1beta")
|
||||
router_v1beta = APIRouter(prefix="/v1beta")
|
||||
logger = get_gemini_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
@router_v1beta.get("/models")
|
||||
async def list_models(_=Depends(security_service.verify_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取可用的Gemini模型列表"""
|
||||
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
||||
logger.info("Handling Gemini models list request")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
models_json = model_service.get_gemini_models(api_key)
|
||||
|
||||
# 模型名称以及对应的详细信息
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
||||
|
||||
# 添加搜索模型
|
||||
if settings.MODEL_SEARCH:
|
||||
for name in settings.MODEL_SEARCH:
|
||||
model = model_mapping.get(name, None)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-search"
|
||||
display_name = f'{item.get("displayName")} For Search'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
# 添加图像生成模型
|
||||
if settings.MODEL_IMAGE:
|
||||
for name in settings.MODEL_IMAGE:
|
||||
model = model_mapping.get(name, None)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-image"
|
||||
display_name = f'{item.get("displayName")} For Image'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
return models_json
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
|
||||
@router_v1beta.post("/models/{model_name}:generateContent")
|
||||
@RetryHandler(max_retries=3, key_arg="api_key")
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||
"""非流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
||||
@RetryHandler(max_retries=3, key_arg="api_key")
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||
"""流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/verify-key/{api_key}")
|
||||
async def verify_key(api_key: str):
|
||||
key_manager = await get_key_manager()
|
||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||
"""验证Gemini API密钥的有效性"""
|
||||
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
||||
logger.info("Verifying API key validity")
|
||||
|
||||
try:
|
||||
# 使用generate_content接口测试key的有效性
|
||||
gemini_requset = GeminiRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=[{"text": "hi"}]
|
||||
)
|
||||
]
|
||||
)
|
||||
response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
|
||||
if response:
|
||||
return JSONResponse({"status": "valid"})
|
||||
return JSONResponse({"status": "invalid"})
|
||||
except Exception as e:
|
||||
logger.error(f"Key verification failed: {str(e)}")
|
||||
return JSONResponse({"status": "invalid", "error": str(e)})
|
||||
@@ -1,145 +0,0 @@
|
||||
from fastapi import HTTPException, APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_openai_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
|
||||
from app.services.chat.retry_handler import RetryHandler
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
from app.services.image_create_service import ImageCreateService
|
||||
from app.services.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.services.model_service import ModelService
|
||||
from app.services.openai_chat_service import OpenAIChatService
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
|
||||
embedding_service = EmbeddingService(settings.BASE_URL)
|
||||
image_create_service = ImageCreateService()
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
@router.get("/v1/models")
|
||||
@router.get("/hf/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
logger.info("-" * 50 + "list_models" + "-" * 50)
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
return model_service.get_gemini_openai_models(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting models list: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@router.post("/hf/v1/chat/completions")
|
||||
@RetryHandler(max_retries=3, key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
api_key = await key_manager.get_paid_key()
|
||||
chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
|
||||
logger.info("-" * 50 + "chat_completion" + "-" * 50)
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(request.model):
|
||||
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
|
||||
|
||||
try:
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
response = await chat_service.create_image_chat_completion(request=request)
|
||||
else:
|
||||
response = await chat_service.create_chat_completion(request, api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
logger.info("Chat completion request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@router.post("/hf/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
):
|
||||
logger.info("-" * 50 + "generate_image" + "-" * 50)
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
|
||||
try:
|
||||
response = image_create_service.generate_images(request)
|
||||
logger.info("Image generation request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Image generation request failed") from e
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@router.post("/hf/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
logger.info("-" * 50 + "embedding" + "-" * 50)
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
response = await embedding_service.create_embedding(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
logger.info("Embedding request successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Embedding request failed") from e
|
||||
|
||||
@router.get("/v1/keys/list")
|
||||
@router.get("/hf/v1/keys/list")
|
||||
async def get_keys_list(
|
||||
_=Depends(security_service.verify_auth_token),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""获取有效和无效的API key列表"""
|
||||
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
|
||||
logger.info("Handling keys list request")
|
||||
try:
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
return {
|
||||
"status": "success",
|
||||
"data": {
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"]
|
||||
},
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting keys list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error while fetching keys list"
|
||||
) from e
|
||||
470
app/config/config.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
应用程序配置模块
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from pydantic import ValidationError, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from sqlalchemy import insert, select, update
|
||||
|
||||
from app.core.constants import (
|
||||
API_VERSION,
|
||||
DEFAULT_CREATE_IMAGE_MODEL,
|
||||
DEFAULT_FILTER_MODELS,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_SAFETY_SETTINGS,
|
||||
DEFAULT_STREAM_CHUNK_SIZE,
|
||||
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||
DEFAULT_STREAM_MAX_DELAY,
|
||||
DEFAULT_STREAM_MIN_DELAY,
|
||||
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||
DEFAULT_TIMEOUT,
|
||||
MAX_RETRIES,
|
||||
)
|
||||
from app.log.logger import Logger
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# 数据库配置
|
||||
DATABASE_TYPE: str = "mysql" # sqlite 或 mysql
|
||||
SQLITE_DATABASE: str = "default_db"
|
||||
MYSQL_HOST: str = ""
|
||||
MYSQL_PORT: int = 3306
|
||||
MYSQL_USER: str = ""
|
||||
MYSQL_PASSWORD: str = ""
|
||||
MYSQL_DATABASE: str = ""
|
||||
MYSQL_SOCKET: str = ""
|
||||
|
||||
# 验证 MySQL 配置
|
||||
@field_validator(
|
||||
"MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE"
|
||||
)
|
||||
def validate_mysql_config(cls, v: Any, info: ValidationInfo) -> Any:
|
||||
if info.data.get("DATABASE_TYPE") == "mysql":
|
||||
if v is None or v == "":
|
||||
raise ValueError(
|
||||
"MySQL configuration is required when DATABASE_TYPE is 'mysql'"
|
||||
)
|
||||
return v
|
||||
|
||||
# API相关配置
|
||||
API_KEYS: List[str]
|
||||
ALLOWED_TOKENS: List[str]
|
||||
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
||||
AUTH_TOKEN: str = ""
|
||||
MAX_FAILURES: int = 3
|
||||
TEST_MODEL: str = DEFAULT_MODEL
|
||||
TIME_OUT: int = DEFAULT_TIMEOUT
|
||||
MAX_RETRIES: int = MAX_RETRIES
|
||||
PROXIES: List[str] = [] # 新增:代理服务器列表
|
||||
|
||||
# 模型相关配置
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
THINKING_MODELS: List[str] = [] # 新增:用于思考过程的模型列表
|
||||
THINKING_BUDGET_MAP: Dict[str, float] = {} # 新增:模型对应的预算映射
|
||||
|
||||
# 图像生成相关配置
|
||||
PAID_KEY: str = ""
|
||||
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
|
||||
UPLOAD_PROVIDER: str = "smms"
|
||||
SMMS_SECRET_TOKEN: str = ""
|
||||
PICGO_API_KEY: str = ""
|
||||
CLOUDFLARE_IMGBED_URL: str = ""
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||
|
||||
# 流式输出优化器配置
|
||||
STREAM_OPTIMIZER_ENABLED: bool = False
|
||||
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
|
||||
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
|
||||
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
||||
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
|
||||
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
|
||||
|
||||
# 假流式配置 (Fake Streaming Configuration)
|
||||
FAKE_STREAM_ENABLED: bool = False # 是否启用假流式输出
|
||||
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: int = 5 # 假流式发送空数据的间隔时间(秒)
|
||||
|
||||
# 调度器配置
|
||||
CHECK_INTERVAL_HOURS: int = 1 # 默认检查间隔为1小时
|
||||
TIMEZONE: str = "Asia/Shanghai" # 默认时区
|
||||
|
||||
# github
|
||||
GITHUB_REPO_OWNER: str = "snailyp"
|
||||
GITHUB_REPO_NAME: str = "gemini-balance"
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = "INFO" # 默认日志级别
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True # 是否开启自动删除错误日志
|
||||
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7 # 自动删除多少天前的错误日志 (1, 7, 30)
|
||||
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False # 是否开启自动删除请求日志
|
||||
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30 # 自动删除多少天前的请求日志 (1, 7, 30)
|
||||
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS # 新增:安全设置
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# 设置默认AUTH_TOKEN(如果未提供)
|
||||
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
|
||||
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
"""尝试将数据库字符串值解析为目标 Python 类型"""
|
||||
from app.log.logger import get_config_logger # 函数内导入
|
||||
|
||||
logger = get_config_logger() # 函数内初始化
|
||||
try:
|
||||
# 处理 List[str]
|
||||
if target_type == List[str]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
logger.warning(
|
||||
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
||||
)
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
# 处理 Dict[str, float]
|
||||
elif target_type == Dict[str, float]:
|
||||
parsed_dict = {}
|
||||
try:
|
||||
# First attempt: standard JSON parsing
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
||||
# Second attempt: try replacing single quotes if JSONDecodeError occurred
|
||||
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
||||
logger.warning(
|
||||
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
||||
)
|
||||
try:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
||||
)
|
||||
else:
|
||||
# Log other errors (ValueError, TypeError) or JSON errors without single quotes
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
||||
)
|
||||
return parsed_dict # Return the parsed dict or an empty one if all attempts fail
|
||||
# 处理 List[Dict[str, str]]
|
||||
elif target_type == List[Dict[str, str]]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
# 验证列表中的每个元素是否为字典,并且键和值都是字符串
|
||||
valid = all(
|
||||
isinstance(item, dict)
|
||||
and all(isinstance(k, str) for k in item.keys())
|
||||
and all(isinstance(v, str) for v in item.values())
|
||||
for item in parsed
|
||||
)
|
||||
if valid:
|
||||
return parsed
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
||||
)
|
||||
return [] # 或者返回默认值?这里返回空列表
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
||||
)
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
||||
)
|
||||
return []
|
||||
# 处理 bool
|
||||
elif target_type == bool:
|
||||
return db_value.lower() in ("true", "1", "yes", "on")
|
||||
# 处理 int
|
||||
elif target_type == int:
|
||||
return int(db_value)
|
||||
# 处理 float
|
||||
elif target_type == float:
|
||||
return float(db_value)
|
||||
# 默认为 str 或其他 pydantic 能直接处理的类型
|
||||
else:
|
||||
return db_value
|
||||
except (ValueError, TypeError, json.JSONDecodeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value."
|
||||
)
|
||||
return db_value # 解析失败则返回原始字符串
|
||||
|
||||
|
||||
async def sync_initial_settings():
|
||||
"""
|
||||
应用启动时同步配置:
|
||||
1. 从数据库加载设置。
|
||||
2. 将数据库设置合并到内存 settings (数据库优先)。
|
||||
3. 将最终的内存 settings 同步回数据库。
|
||||
"""
|
||||
from app.log.logger import get_config_logger # 函数内导入
|
||||
|
||||
logger = get_config_logger() # 函数内初始化
|
||||
# 延迟导入以避免循环依赖和确保数据库连接已初始化
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings as SettingsModel
|
||||
|
||||
global settings
|
||||
logger.info("Starting initial settings synchronization...")
|
||||
|
||||
if not database.is_connected:
|
||||
try:
|
||||
await database.connect()
|
||||
logger.info("Database connection established for initial sync.")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to connect to database for initial settings sync: {e}. Skipping sync."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# 1. 从数据库加载设置
|
||||
db_settings_raw: List[Dict[str, Any]] = []
|
||||
try:
|
||||
query = select(SettingsModel.key, SettingsModel.value)
|
||||
results = await database.fetch_all(query)
|
||||
db_settings_raw = [
|
||||
{"key": row["key"], "value": row["value"]} for row in results
|
||||
]
|
||||
logger.info(f"Fetched {len(db_settings_raw)} settings from database.")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings."
|
||||
)
|
||||
# 即使数据库读取失败,也要继续执行,确保基于 env/dotenv 的配置能同步到数据库
|
||||
|
||||
db_settings_map: Dict[str, str] = {
|
||||
s["key"]: s["value"] for s in db_settings_raw
|
||||
}
|
||||
|
||||
# 2. 将数据库设置合并到内存 settings (数据库优先)
|
||||
updated_in_memory = False
|
||||
|
||||
for key, db_value in db_settings_map.items():
|
||||
if key == "DATABASE_TYPE":
|
||||
logger.debug(
|
||||
f"Skipping update of '{key}' in memory from database. "
|
||||
"This setting is controlled by environment/dotenv."
|
||||
)
|
||||
continue
|
||||
if hasattr(settings, key):
|
||||
target_type = Settings.__annotations__.get(key)
|
||||
if target_type:
|
||||
try:
|
||||
parsed_db_value = _parse_db_value(key, db_value, target_type)
|
||||
memory_value = getattr(settings, key)
|
||||
|
||||
# 比较解析后的值和内存中的值
|
||||
# 注意:对于列表等复杂类型,直接比较可能不够健壮,但这里简化处理
|
||||
if parsed_db_value != memory_value:
|
||||
# 检查类型是否匹配,以防解析函数返回了不兼容的类型
|
||||
type_match = False
|
||||
if target_type == List[str] and isinstance(
|
||||
parsed_db_value, list
|
||||
):
|
||||
type_match = True
|
||||
elif target_type == Dict[str, float] and isinstance(
|
||||
parsed_db_value, dict
|
||||
):
|
||||
type_match = True
|
||||
elif target_type not in (
|
||||
List[str],
|
||||
Dict[str, float],
|
||||
) and isinstance(parsed_db_value, target_type):
|
||||
type_match = True
|
||||
|
||||
if type_match:
|
||||
setattr(settings, key, parsed_db_value)
|
||||
logger.debug(
|
||||
f"Updated setting '{key}' in memory from database value ({target_type})."
|
||||
)
|
||||
updated_in_memory = True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing database setting for key '{key}': {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Database setting '{key}' not found in Settings model definition. Ignoring."
|
||||
)
|
||||
|
||||
# 如果内存中有更新,重新验证 Pydantic 模型(可选但推荐)
|
||||
if updated_in_memory:
|
||||
try:
|
||||
# 重新加载以确保类型转换和验证
|
||||
settings = Settings(**settings.model_dump())
|
||||
logger.info(
|
||||
"Settings object re-validated after merging database values."
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.error(
|
||||
f"Validation error after merging database settings: {e}. Settings might be inconsistent."
|
||||
)
|
||||
|
||||
# 3. 将最终的内存 settings 同步回数据库
|
||||
final_memory_settings = settings.model_dump()
|
||||
settings_to_update: List[Dict[str, Any]] = []
|
||||
settings_to_insert: List[Dict[str, Any]] = []
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
existing_db_keys = set(db_settings_map.keys())
|
||||
|
||||
for key, value in final_memory_settings.items():
|
||||
if key == "DATABASE_TYPE":
|
||||
logger.debug(
|
||||
f"Skipping synchronization of '{key}' to database. "
|
||||
"This setting is controlled by environment/dotenv."
|
||||
)
|
||||
continue
|
||||
|
||||
# 序列化值为字符串或 JSON 字符串
|
||||
if isinstance(value, (list, dict)): # 处理列表和字典
|
||||
db_value = json.dumps(
|
||||
value, ensure_ascii=False
|
||||
) # 使用 ensure_ascii=False 以支持非 ASCII 字符
|
||||
elif isinstance(value, bool):
|
||||
db_value = str(value).lower()
|
||||
elif value is None: # 处理 None 值
|
||||
db_value = "" # 或者根据需要设为 NULL 或其他标记
|
||||
else:
|
||||
db_value = str(value)
|
||||
|
||||
data = {
|
||||
"key": key,
|
||||
"value": db_value,
|
||||
"description": f"{key} configuration setting", # 默认描述
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
if key in existing_db_keys:
|
||||
# 仅当值与数据库中的不同时才更新
|
||||
if db_settings_map[key] != db_value:
|
||||
settings_to_update.append(data)
|
||||
else:
|
||||
# 如果键不在数据库中,则插入
|
||||
data["created_at"] = now
|
||||
settings_to_insert.append(data)
|
||||
|
||||
# 在事务中执行批量插入和更新
|
||||
if settings_to_insert or settings_to_update:
|
||||
try:
|
||||
async with database.transaction():
|
||||
if settings_to_insert:
|
||||
# 获取现有描述以避免覆盖
|
||||
query_existing = select(
|
||||
SettingsModel.key, SettingsModel.description
|
||||
).where(
|
||||
SettingsModel.key.in_(
|
||||
[s["key"] for s in settings_to_insert]
|
||||
)
|
||||
)
|
||||
existing_desc = {
|
||||
row["key"]: row["description"]
|
||||
for row in await database.fetch_all(query_existing)
|
||||
}
|
||||
for item in settings_to_insert:
|
||||
item["description"] = existing_desc.get(
|
||||
item["key"], item["description"]
|
||||
)
|
||||
|
||||
query_insert = insert(SettingsModel).values(settings_to_insert)
|
||||
await database.execute(query=query_insert)
|
||||
logger.info(
|
||||
f"Synced (inserted) {len(settings_to_insert)} settings to database."
|
||||
)
|
||||
|
||||
if settings_to_update:
|
||||
# 获取现有描述以避免覆盖
|
||||
query_existing = select(
|
||||
SettingsModel.key, SettingsModel.description
|
||||
).where(
|
||||
SettingsModel.key.in_(
|
||||
[s["key"] for s in settings_to_update]
|
||||
)
|
||||
)
|
||||
existing_desc = {
|
||||
row["key"]: row["description"]
|
||||
for row in await database.fetch_all(query_existing)
|
||||
}
|
||||
|
||||
for setting_data in settings_to_update:
|
||||
setting_data["description"] = existing_desc.get(
|
||||
setting_data["key"], setting_data["description"]
|
||||
)
|
||||
query_update = (
|
||||
update(SettingsModel)
|
||||
.where(SettingsModel.key == setting_data["key"])
|
||||
.values(
|
||||
value=setting_data["value"],
|
||||
description=setting_data["description"],
|
||||
updated_at=setting_data["updated_at"],
|
||||
)
|
||||
)
|
||||
await database.execute(query=query_update)
|
||||
logger.info(
|
||||
f"Synced (updated) {len(settings_to_update)} settings to database."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to sync settings to database during startup: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"No setting changes detected between memory and database during initial sync."
|
||||
)
|
||||
|
||||
# 刷新日志等级
|
||||
Logger.update_log_levels(final_memory_settings.get("LOG_LEVEL"))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred during initial settings sync: {e}")
|
||||
finally:
|
||||
if database.is_connected:
|
||||
try:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting database after initial sync: {e}")
|
||||
|
||||
logger.info("Initial settings synchronization finished.")
|
||||
165
app/core/application.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path # Add pathlib import
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.config.config import settings, sync_initial_settings
|
||||
from app.log.logger import get_application_logger
|
||||
from app.middleware.middleware import setup_middlewares
|
||||
from app.exception.exceptions import setup_exception_handlers
|
||||
from app.router.routes import setup_routers
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.database.connection import connect_to_db, disconnect_from_db
|
||||
from app.utils.helpers import get_current_version # Import from helpers
|
||||
from app.database.initialization import initialize_database
|
||||
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
||||
from app.service.update.update_service import check_for_updates
|
||||
|
||||
logger = get_application_logger()
|
||||
|
||||
# Define project paths using pathlib
|
||||
# Assuming this file is at app/core/application.py
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
# VERSION_FILE_PATH = PROJECT_ROOT / "VERSION" # Removed: Defined in helpers.py
|
||||
STATIC_DIR = PROJECT_ROOT / "app" / "static"
|
||||
TEMPLATES_DIR = PROJECT_ROOT / "app" / "templates"
|
||||
|
||||
# Removed _get_current_version function definition, moved to helpers.py
|
||||
|
||||
# 初始化模板引擎,并添加全局变量
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
# 定义一个函数来更新模板全局变量
|
||||
def update_template_globals(app: FastAPI, update_info: dict):
|
||||
# Jinja2Templates 实例没有直接更新全局变量的方法
|
||||
# 我们需要在请求上下文中传递这些变量,或者修改 Jinja 环境
|
||||
# 更简单的方法是将其存储在 app.state 中,并在渲染时传递
|
||||
app.state.update_info = update_info
|
||||
logger.info(f"Update info stored in app.state: {update_info}")
|
||||
|
||||
|
||||
# --- Helper functions for lifespan ---
|
||||
|
||||
async def _setup_database_and_config(app_settings):
|
||||
"""Initializes database, syncs settings, and initializes KeyManager."""
|
||||
initialize_database()
|
||||
logger.info("Database initialized successfully")
|
||||
await connect_to_db()
|
||||
await sync_initial_settings()
|
||||
# Initialize KeyManager using potentially updated settings
|
||||
await get_key_manager_instance(app_settings.API_KEYS)
|
||||
logger.info("Database, config sync, and KeyManager initialized successfully")
|
||||
|
||||
async def _shutdown_database():
|
||||
"""Disconnects from the database."""
|
||||
await disconnect_from_db()
|
||||
|
||||
def _start_scheduler():
|
||||
"""Starts the background scheduler."""
|
||||
try:
|
||||
start_scheduler()
|
||||
logger.info("Scheduler started successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start scheduler: {e}")
|
||||
|
||||
def _stop_scheduler():
|
||||
"""Stops the background scheduler."""
|
||||
stop_scheduler()
|
||||
|
||||
async def _perform_update_check(app: FastAPI):
|
||||
"""Checks for updates and stores the info in app.state."""
|
||||
update_available, latest_version, error_message = await check_for_updates()
|
||||
current_version = get_current_version() # Use imported function
|
||||
update_info = {
|
||||
"update_available": update_available,
|
||||
"latest_version": latest_version,
|
||||
"error_message": error_message,
|
||||
"current_version": current_version
|
||||
}
|
||||
# Ensure app.state exists and store update info
|
||||
if not hasattr(app, "state"):
|
||||
from starlette.datastructures import State
|
||||
app.state = State()
|
||||
app.state.update_info = update_info
|
||||
logger.info(f"Update check completed. Info: {update_info}")
|
||||
|
||||
# --- Application Lifespan ---
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Manages the application startup and shutdown events.
|
||||
|
||||
Args:
|
||||
app: FastAPI应用实例
|
||||
"""
|
||||
# Startup events
|
||||
logger.info("Application starting up...")
|
||||
try:
|
||||
# Setup database, config, and KeyManager
|
||||
await _setup_database_and_config(settings) # Pass settings object
|
||||
|
||||
# Perform update check after core components are ready
|
||||
# await _perform_update_check(app) # Removed: Version check moved to frontend API call
|
||||
|
||||
# Start the scheduler
|
||||
_start_scheduler()
|
||||
|
||||
except Exception as e:
|
||||
logger.critical(f"Critical error during application startup: {str(e)}", exc_info=True)
|
||||
# Depending on the severity, you might want to prevent the app from fully starting
|
||||
# For now, we log critically and let it yield, potentially in a broken state.
|
||||
# Consider adding more robust error handling here if startup failures should halt the app.
|
||||
|
||||
yield # Application runs
|
||||
|
||||
# Shutdown events
|
||||
logger.info("Application shutting down...")
|
||||
_stop_scheduler()
|
||||
await _shutdown_database()
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""
|
||||
创建并配置FastAPI应用程序实例
|
||||
|
||||
Returns:
|
||||
FastAPI: 配置好的FastAPI应用程序实例
|
||||
"""
|
||||
# Removed: initialize_app() call
|
||||
|
||||
# 创建FastAPI应用
|
||||
# Read version from file for consistency
|
||||
current_version = get_current_version() # Use imported function
|
||||
app = FastAPI(
|
||||
title="Gemini Balance API",
|
||||
description="Gemini API代理服务,支持负载均衡和密钥管理",
|
||||
version=current_version,
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Initialize app.state early to ensure it exists before lifespan potentially uses it
|
||||
if not hasattr(app, "state"):
|
||||
from starlette.datastructures import State
|
||||
app.state = State()
|
||||
# Set a default/initial state for update_info
|
||||
app.state.update_info = {
|
||||
"update_available": False,
|
||||
"latest_version": None,
|
||||
"error_message": "Initializing...",
|
||||
"current_version": current_version # Use version read earlier
|
||||
}
|
||||
|
||||
# 配置静态文件
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
|
||||
# 配置中间件
|
||||
setup_middlewares(app)
|
||||
|
||||
# 配置异常处理器
|
||||
setup_exception_handlers(app)
|
||||
|
||||
# 配置路由
|
||||
setup_routers(app)
|
||||
|
||||
return app
|
||||
@@ -1,41 +0,0 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import List
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
API_KEYS: List[str]
|
||||
ALLOWED_TOKENS: List[str]
|
||||
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
|
||||
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
|
||||
MODEL_IMAGE: List[str] = ["gemini-2.0-flash-exp"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
AUTH_TOKEN: str = ""
|
||||
MAX_FAILURES: int = 3
|
||||
PAID_KEY: str = ""
|
||||
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
|
||||
UPLOAD_PROVIDER: str = "smms"
|
||||
SMMS_SECRET_TOKEN: str = ""
|
||||
PICGO_API_KEY: str = ""
|
||||
CLOUDFLARE_IMGBED_URL: str = ""
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||
TEST_MODEL: str = "gemini-1.5-flash"
|
||||
|
||||
# 流式输出优化器配置
|
||||
STREAM_MIN_DELAY: float = 0.016
|
||||
STREAM_MAX_DELAY: float = 0.024
|
||||
STREAM_SHORT_TEXT_THRESHOLD: int = 10
|
||||
STREAM_LONG_TEXT_THRESHOLD: int = 50
|
||||
STREAM_CHUNK_SIZE: int = 5
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if not self.AUTH_TOKEN:
|
||||
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else ""
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
79
app/core/constants.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
常量定义模块
|
||||
"""
|
||||
|
||||
# API相关常量
|
||||
API_VERSION = "v1beta"
|
||||
DEFAULT_TIMEOUT = 300 # 秒
|
||||
MAX_RETRIES = 3 # 最大重试次数
|
||||
|
||||
# 模型相关常量
|
||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||
DEFAULT_MODEL = "gemini-1.5-flash"
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
DEFAULT_MAX_TOKENS = 8192
|
||||
DEFAULT_TOP_P = 0.9
|
||||
DEFAULT_TOP_K = 40
|
||||
DEFAULT_FILTER_MODELS = [
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001"
|
||||
]
|
||||
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||
|
||||
# 图像生成相关常量
|
||||
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
||||
|
||||
# 上传提供商
|
||||
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
|
||||
DEFAULT_UPLOAD_PROVIDER = "smms"
|
||||
|
||||
# 流式输出相关常量
|
||||
DEFAULT_STREAM_MIN_DELAY = 0.016
|
||||
DEFAULT_STREAM_MAX_DELAY = 0.024
|
||||
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
|
||||
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
||||
DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||
|
||||
# 正则表达式模式
|
||||
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||
|
||||
# Audio/Video Settings
|
||||
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
||||
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
||||
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
|
||||
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
||||
AUDIO_FORMAT_TO_MIMETYPE = {
|
||||
"wav": "audio/wav",
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"ogg": "audio/ogg",
|
||||
}
|
||||
|
||||
VIDEO_FORMAT_TO_MIMETYPE = {
|
||||
"mp4": "video/mp4",
|
||||
"mov": "video/quicktime",
|
||||
"avi": "video/x-msvideo",
|
||||
"webm": "video/webm",
|
||||
}
|
||||
|
||||
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
|
||||
DEFAULT_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
@@ -1,135 +0,0 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, Optional
|
||||
import platform
|
||||
|
||||
# ANSI转义序列颜色代码
|
||||
COLORS = {
|
||||
'DEBUG': '\033[34m', # 蓝色
|
||||
'INFO': '\033[32m', # 绿色
|
||||
'WARNING': '\033[33m', # 黄色
|
||||
'ERROR': '\033[31m', # 红色
|
||||
'CRITICAL': '\033[1;31m' # 红色加粗
|
||||
}
|
||||
|
||||
# Windows系统启用ANSI支持
|
||||
if platform.system() == 'Windows':
|
||||
import ctypes
|
||||
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""
|
||||
自定义的日志格式化器,添加颜色支持
|
||||
"""
|
||||
|
||||
def format(self, record):
|
||||
# 获取对应级别的颜色代码
|
||||
color = COLORS.get(record.levelname, '')
|
||||
# 添加颜色代码和重置代码
|
||||
record.levelname = f"{color}{record.levelname}\033[0m"
|
||||
return super().format(record)
|
||||
|
||||
|
||||
# 日志格式
|
||||
FORMATTER = ColoredFormatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
|
||||
)
|
||||
|
||||
# 日志级别映射
|
||||
LOG_LEVELS = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
@staticmethod
|
||||
def setup_logger(
|
||||
name: str,
|
||||
level: str = "debug",
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
设置并获取logger
|
||||
:param name: logger名称
|
||||
:param level: 日志级别
|
||||
:return: logger实例
|
||||
"""
|
||||
if name in Logger._loggers:
|
||||
return Logger._loggers[name]
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(LOG_LEVELS.get(level.lower(), logging.INFO))
|
||||
logger.propagate = False
|
||||
|
||||
# 添加控制台输出
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(FORMATTER)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
Logger._loggers[name] = logger
|
||||
return logger
|
||||
|
||||
@staticmethod
|
||||
def get_logger(name: str) -> Optional[logging.Logger]:
|
||||
"""
|
||||
获取已存在的logger
|
||||
:param name: logger名称
|
||||
:return: logger实例或None
|
||||
"""
|
||||
return Logger._loggers.get(name)
|
||||
|
||||
|
||||
# 预定义的loggers
|
||||
def get_openai_logger():
|
||||
return Logger.setup_logger("openai")
|
||||
|
||||
|
||||
def get_gemini_logger():
|
||||
return Logger.setup_logger("gemini")
|
||||
|
||||
|
||||
def get_chat_logger():
|
||||
return Logger.setup_logger("chat")
|
||||
|
||||
|
||||
def get_model_logger():
|
||||
return Logger.setup_logger("model")
|
||||
|
||||
|
||||
def get_security_logger():
|
||||
return Logger.setup_logger("security")
|
||||
|
||||
|
||||
def get_key_manager_logger():
|
||||
return Logger.setup_logger("key_manager")
|
||||
|
||||
|
||||
def get_main_logger():
|
||||
return Logger.setup_logger("main")
|
||||
|
||||
|
||||
def get_embeddings_logger():
|
||||
return Logger.setup_logger("embeddings")
|
||||
|
||||
|
||||
def get_request_logger():
|
||||
return Logger.setup_logger("request")
|
||||
|
||||
|
||||
def get_retry_logger():
|
||||
return Logger.setup_logger("retry")
|
||||
|
||||
|
||||
def get_image_create_logger():
|
||||
return Logger.setup_logger("image_create")
|
||||
@@ -1,26 +1,27 @@
|
||||
from fastapi import HTTPException, Header
|
||||
from typing import Optional
|
||||
from app.core.logger import get_security_logger
|
||||
from app.core.config import settings
|
||||
|
||||
from fastapi import Header, HTTPException
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_security_logger
|
||||
|
||||
logger = get_security_logger()
|
||||
|
||||
|
||||
def verify_auth_token(token: str) -> bool:
|
||||
return token == settings.AUTH_TOKEN
|
||||
|
||||
|
||||
class SecurityService:
|
||||
def __init__(self, allowed_tokens: list, auth_token: str):
|
||||
self.allowed_tokens = allowed_tokens
|
||||
self.auth_token = auth_token
|
||||
|
||||
async def verify_key(self, key: str):
|
||||
if key not in self.allowed_tokens and key != self.auth_token:
|
||||
if key not in settings.ALLOWED_TOKENS and key != settings.AUTH_TOKEN:
|
||||
logger.error("Invalid key")
|
||||
raise HTTPException(status_code=401, detail="Invalid key")
|
||||
return key
|
||||
|
||||
async def verify_authorization(
|
||||
self, authorization: Optional[str] = Header(None)
|
||||
self, authorization: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
if not authorization:
|
||||
logger.error("Missing Authorization header")
|
||||
@@ -33,31 +34,57 @@ class SecurityService:
|
||||
)
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
if token not in self.allowed_tokens and token != self.auth_token:
|
||||
if token not in settings.ALLOWED_TOKENS and token != settings.AUTH_TOKEN:
|
||||
logger.error("Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
return token
|
||||
|
||||
async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str:
|
||||
async def verify_goog_api_key(
|
||||
self, x_goog_api_key: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
"""验证Google API Key"""
|
||||
if not x_goog_api_key:
|
||||
logger.error("Missing x-goog-api-key header")
|
||||
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
|
||||
|
||||
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
|
||||
if (
|
||||
x_goog_api_key not in settings.ALLOWED_TOKENS
|
||||
and x_goog_api_key != settings.AUTH_TOKEN
|
||||
):
|
||||
logger.error("Invalid x-goog-api-key")
|
||||
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
|
||||
|
||||
return x_goog_api_key
|
||||
|
||||
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
|
||||
async def verify_auth_token(
|
||||
self, authorization: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
if not authorization:
|
||||
logger.error("Missing auth_token header")
|
||||
raise HTTPException(status_code=401, detail="Missing auth_token header")
|
||||
token = authorization.replace("Bearer ", "")
|
||||
if token != self.auth_token:
|
||||
if token != settings.AUTH_TOKEN:
|
||||
logger.error("Invalid auth_token")
|
||||
raise HTTPException(status_code=401, detail="Invalid auth_token")
|
||||
|
||||
return token
|
||||
|
||||
async def verify_key_or_goog_api_key(
|
||||
self, key: Optional[str] = None , x_goog_api_key: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
"""验证URL中的key或请求头中的x-goog-api-key"""
|
||||
# 如果URL中的key有效,直接返回
|
||||
if key in settings.ALLOWED_TOKENS or key == settings.AUTH_TOKEN:
|
||||
return key
|
||||
|
||||
# 否则检查请求头中的x-goog-api-key
|
||||
if not x_goog_api_key:
|
||||
logger.error("Invalid key and missing x-goog-api-key header")
|
||||
raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")
|
||||
|
||||
if x_goog_api_key not in settings.ALLOWED_TOKENS and x_goog_api_key != settings.AUTH_TOKEN:
|
||||
logger.error("Invalid key and invalid x-goog-api-key")
|
||||
raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")
|
||||
|
||||
return x_goog_api_key
|
||||
3
app/database/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
数据库模块
|
||||
"""
|
||||
74
app/database/connection.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
数据库连接池模块
|
||||
"""
|
||||
from pathlib import Path
|
||||
from databases import Database
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
# from sqlalchemy.orm import sessionmaker # 不再需要
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
|
||||
# 数据库URL
|
||||
if settings.DATABASE_TYPE == "sqlite":
|
||||
# 确保 data 目录存在
|
||||
data_dir = Path("data")
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
db_path = data_dir / settings.SQLITE_DATABASE
|
||||
DATABASE_URL = f"sqlite:///{db_path}"
|
||||
elif settings.DATABASE_TYPE == "mysql":
|
||||
if settings.MYSQL_SOCKET:
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
|
||||
else:
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
|
||||
else:
|
||||
raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.")
|
||||
|
||||
# 创建数据库引擎
|
||||
# pool_pre_ping=True: 在从连接池获取连接前执行简单的 "ping" 测试,确保连接有效
|
||||
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
|
||||
|
||||
# 创建元数据对象
|
||||
metadata = MetaData()
|
||||
|
||||
# 创建基类
|
||||
Base = declarative_base(metadata=metadata)
|
||||
|
||||
# 创建数据库连接池,并配置连接池参数,在sqlite中不使用连接池
|
||||
# min_size/max_size: 连接池的最小/最大连接数
|
||||
# pool_recycle=3600: 连接在池中允许存在的最大秒数(生命周期)。
|
||||
# 设置为 3600 秒(1小时),确保在 MySQL 默认的 wait_timeout (通常8小时) 或其他网络超时之前回收连接。
|
||||
# 如果遇到连接失效问题,可以尝试调低此值,使其小于实际的 wait_timeout 或网络超时时间。
|
||||
# databases 库会自动处理连接失效后的重连尝试。
|
||||
if settings.DATABASE_TYPE == "sqlite":
|
||||
database = Database(DATABASE_URL)
|
||||
else:
|
||||
database = Database(DATABASE_URL, min_size=5, max_size=20, pool_recycle=1800) # Reduced recycle time to 30 mins
|
||||
|
||||
# 移除了 SessionLocal 和 get_db 函数
|
||||
|
||||
# --- Async connection functions for lifespan/async routes ---
|
||||
async def connect_to_db():
|
||||
"""
|
||||
连接到数据库
|
||||
"""
|
||||
try:
|
||||
await database.connect()
|
||||
logger.info(f"Connected to {settings.DATABASE_TYPE}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to database: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def disconnect_from_db():
|
||||
"""
|
||||
断开数据库连接
|
||||
"""
|
||||
try:
|
||||
await database.disconnect()
|
||||
logger.info(f"Disconnected from {settings.DATABASE_TYPE}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disconnect from database: {str(e)}")
|
||||
77
app/database/initialization.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
数据库初始化模块
|
||||
"""
|
||||
from dotenv import dotenv_values
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.connection import engine, Base
|
||||
from app.database.models import Settings
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
|
||||
|
||||
def create_tables():
|
||||
"""
|
||||
创建数据库表
|
||||
"""
|
||||
try:
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(engine)
|
||||
logger.info("Database tables created successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create database tables: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def import_env_to_settings():
|
||||
"""
|
||||
将.env文件中的配置项导入到t_settings表中
|
||||
"""
|
||||
try:
|
||||
# 获取.env文件中的所有配置项
|
||||
env_values = dotenv_values(".env")
|
||||
|
||||
# 获取检查器
|
||||
inspector = inspect(engine)
|
||||
|
||||
# 检查t_settings表是否存在
|
||||
if "t_settings" in inspector.get_table_names():
|
||||
# 使用Session进行数据库操作
|
||||
with Session(engine) as session:
|
||||
# 获取所有现有的配置项
|
||||
current_settings = {setting.key: setting for setting in session.query(Settings).all()}
|
||||
|
||||
# 遍历所有配置项
|
||||
for key, value in env_values.items():
|
||||
# 检查配置项是否已存在
|
||||
if key not in current_settings:
|
||||
# 插入配置项
|
||||
new_setting = Settings(key=key, value=value)
|
||||
session.add(new_setting)
|
||||
logger.info(f"Inserted setting: {key}")
|
||||
|
||||
# 提交事务
|
||||
session.commit()
|
||||
|
||||
logger.info("Environment variables imported to settings table successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import environment variables to settings table: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""
|
||||
初始化数据库
|
||||
"""
|
||||
try:
|
||||
# 创建表
|
||||
create_tables()
|
||||
|
||||
# 导入环境变量
|
||||
import_env_to_settings()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
raise
|
||||
61
app/database/models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
数据库模型模块
|
||||
"""
|
||||
import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean # 添加 Boolean
|
||||
|
||||
from app.database.connection import Base
|
||||
|
||||
|
||||
class Settings(Base):
|
||||
"""
|
||||
设置表,对应.env中的配置项
|
||||
"""
|
||||
__tablename__ = "t_settings"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
key = Column(String(100), nullable=False, unique=True, comment="配置项键名")
|
||||
value = Column(Text, nullable=True, comment="配置项值")
|
||||
description = Column(String(255), nullable=True, comment="配置项描述")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Settings(key='{self.key}', value='{self.value}')>"
|
||||
|
||||
|
||||
class ErrorLog(Base):
|
||||
"""
|
||||
错误日志表
|
||||
"""
|
||||
__tablename__ = "t_error_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
gemini_key = Column(String(100), nullable=True, comment="Gemini API密钥")
|
||||
model_name = Column(String(100), nullable=True, comment="模型名称")
|
||||
error_type = Column(String(50), nullable=True, comment="错误类型")
|
||||
error_log = Column(Text, nullable=True, comment="错误日志")
|
||||
error_code = Column(Integer, nullable=True, comment="错误代码")
|
||||
request_msg = Column(JSON, nullable=True, comment="请求消息")
|
||||
request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ErrorLog(id='{self.id}', gemini_key='{self.gemini_key}')>"
|
||||
|
||||
# 新增 RequestLog 模型
|
||||
class RequestLog(Base):
|
||||
"""
|
||||
API 请求日志表
|
||||
"""
|
||||
__tablename__ = "t_request_log"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")
|
||||
model_name = Column(String(100), nullable=True, comment="模型名称")
|
||||
api_key = Column(String(100), nullable=True, comment="使用的API密钥") # 考虑安全性,后续可优化
|
||||
is_success = Column(Boolean, nullable=False, comment="请求是否成功")
|
||||
status_code = Column(Integer, nullable=True, comment="API响应状态码")
|
||||
latency_ms = Column(Integer, nullable=True, comment="请求耗时(毫秒)")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
|
||||
419
app/database/services.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
数据库服务模块
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
from sqlalchemy import func, desc, asc, select, insert, update, delete
|
||||
import json
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings, ErrorLog, RequestLog
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
|
||||
|
||||
async def get_all_settings() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取所有设置
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 设置列表
|
||||
"""
|
||||
try:
|
||||
query = select(Settings)
|
||||
result = await database.fetch_all(query)
|
||||
return [dict(row) for row in result]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get all settings: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_setting(key: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定键的设置
|
||||
|
||||
Args:
|
||||
key: 设置键名
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 设置信息,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
query = select(Settings).where(Settings.key == key)
|
||||
result = await database.fetch_one(query)
|
||||
return dict(result) if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get setting {key}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def update_setting(key: str, value: str, description: Optional[str] = None) -> bool:
|
||||
"""
|
||||
更新设置
|
||||
|
||||
Args:
|
||||
key: 设置键名
|
||||
value: 设置值
|
||||
description: 设置描述
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
try:
|
||||
# 检查设置是否存在
|
||||
setting = await get_setting(key)
|
||||
|
||||
if setting:
|
||||
# 更新设置
|
||||
query = (
|
||||
update(Settings)
|
||||
.where(Settings.key == key)
|
||||
.values(
|
||||
value=value,
|
||||
description=description if description else setting["description"],
|
||||
updated_at=datetime.now() # Use datetime.now()
|
||||
)
|
||||
)
|
||||
await database.execute(query)
|
||||
logger.info(f"Updated setting: {key}")
|
||||
return True
|
||||
else:
|
||||
# 插入设置
|
||||
query = (
|
||||
insert(Settings)
|
||||
.values(
|
||||
key=key,
|
||||
value=value,
|
||||
description=description,
|
||||
created_at=datetime.now(), # Use datetime.now()
|
||||
updated_at=datetime.now() # Use datetime.now()
|
||||
)
|
||||
)
|
||||
await database.execute(query)
|
||||
logger.info(f"Inserted setting: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update setting {key}: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def add_error_log(
|
||||
gemini_key: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
error_type: Optional[str] = None,
|
||||
error_log: Optional[str] = None,
|
||||
error_code: Optional[int] = None,
|
||||
request_msg: Optional[Union[Dict[str, Any], str]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
添加错误日志
|
||||
|
||||
Args:
|
||||
gemini_key: Gemini API密钥
|
||||
error_log: 错误日志
|
||||
error_code: 错误代码 (例如 HTTP 状态码)
|
||||
request_msg: 请求消息
|
||||
|
||||
Returns:
|
||||
bool: 是否添加成功
|
||||
"""
|
||||
try:
|
||||
# 如果request_msg是字典,则转换为JSON字符串
|
||||
if isinstance(request_msg, dict):
|
||||
request_msg_json = request_msg
|
||||
elif isinstance(request_msg, str):
|
||||
try:
|
||||
request_msg_json = json.loads(request_msg)
|
||||
except json.JSONDecodeError:
|
||||
request_msg_json = {"message": request_msg}
|
||||
else:
|
||||
request_msg_json = None
|
||||
|
||||
# 插入错误日志
|
||||
query = (
|
||||
insert(ErrorLog)
|
||||
.values(
|
||||
gemini_key=gemini_key,
|
||||
error_type=error_type,
|
||||
error_log=error_log,
|
||||
model_name=model_name,
|
||||
error_code=error_code,
|
||||
request_msg=request_msg_json,
|
||||
request_time=datetime.now()
|
||||
)
|
||||
)
|
||||
await database.execute(query)
|
||||
logger.info(f"Added error log for key: {gemini_key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add error log: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_error_logs(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
key_search: Optional[str] = None,
|
||||
error_search: Optional[str] = None,
|
||||
error_code_search: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
sort_by: str = 'id', # 新增排序字段
|
||||
sort_order: str = 'desc' # 新增排序顺序 ('asc' or 'desc')
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取错误日志,支持搜索、日期过滤和排序
|
||||
|
||||
Args:
|
||||
limit (int): 限制数量
|
||||
offset (int): 偏移量
|
||||
key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配)
|
||||
error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配)
|
||||
error_code_search (Optional[str]): 错误码搜索词 (精确匹配)
|
||||
start_date (Optional[datetime]): 开始日期时间
|
||||
end_date (Optional[datetime]): 结束日期时间
|
||||
sort_by (str): 排序字段 (例如 'id', 'request_time')
|
||||
sort_order (str): 排序顺序 ('asc' or 'desc')
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 错误日志列表
|
||||
"""
|
||||
try:
|
||||
query = select(
|
||||
ErrorLog.id,
|
||||
ErrorLog.gemini_key,
|
||||
ErrorLog.model_name,
|
||||
ErrorLog.error_type,
|
||||
ErrorLog.error_log,
|
||||
ErrorLog.error_code,
|
||||
ErrorLog.request_time
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if key_search:
|
||||
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
||||
if error_search:
|
||||
query = query.where(
|
||||
(ErrorLog.error_type.ilike(f"%{error_search}%")) |
|
||||
(ErrorLog.error_log.ilike(f"%{error_search}%"))
|
||||
)
|
||||
if start_date:
|
||||
query = query.where(ErrorLog.request_time >= start_date)
|
||||
if end_date:
|
||||
# Use the datetime object directly for comparison
|
||||
query = query.where(ErrorLog.request_time < end_date)
|
||||
if error_code_search:
|
||||
try:
|
||||
# Attempt to convert search string to integer for exact match
|
||||
error_code_int = int(error_code_search)
|
||||
query = query.where(ErrorLog.error_code == error_code_int)
|
||||
except ValueError:
|
||||
# If conversion fails, log a warning and potentially skip this filter
|
||||
# or handle as needed (e.g., return no results for invalid code format)
|
||||
logger.warning(f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter.")
|
||||
# Optionally, force no results if the format is invalid:
|
||||
# query = query.where(False) # This ensures no rows are returned
|
||||
|
||||
# 添加排序逻辑
|
||||
sort_column = getattr(ErrorLog, sort_by, ErrorLog.id) # 获取排序字段,默认为 id
|
||||
if sort_order.lower() == 'asc':
|
||||
query = query.order_by(asc(sort_column))
|
||||
else:
|
||||
query = query.order_by(desc(sort_column))
|
||||
|
||||
# Apply limit and offset
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await database.fetch_all(query)
|
||||
return [dict(row) for row in result]
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error logs with filters: {str(e)}") # Use exception for stack trace
|
||||
raise
|
||||
|
||||
|
||||
async def get_error_logs_count(
|
||||
key_search: Optional[str] = None,
|
||||
error_search: Optional[str] = None,
|
||||
error_code_search: Optional[str] = None, # Added error code search
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> int:
|
||||
"""
|
||||
获取符合条件的错误日志总数
|
||||
|
||||
Args:
|
||||
key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配)
|
||||
error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配)
|
||||
error_code_search (Optional[str]): 错误码搜索词 (精确匹配)
|
||||
start_date (Optional[datetime]): 开始日期时间
|
||||
end_date (Optional[datetime]): 结束日期时间
|
||||
|
||||
Returns:
|
||||
int: 日志总数
|
||||
"""
|
||||
try:
|
||||
query = select(func.count()).select_from(ErrorLog)
|
||||
|
||||
# Apply the same filters as get_error_logs
|
||||
if key_search:
|
||||
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
||||
if error_search:
|
||||
query = query.where(
|
||||
(ErrorLog.error_type.ilike(f"%{error_search}%")) |
|
||||
(ErrorLog.error_log.ilike(f"%{error_search}%"))
|
||||
)
|
||||
if start_date:
|
||||
query = query.where(ErrorLog.request_time >= start_date)
|
||||
if end_date:
|
||||
# Use the datetime object directly for comparison
|
||||
query = query.where(ErrorLog.request_time < end_date)
|
||||
if error_code_search:
|
||||
try:
|
||||
# Attempt to convert search string to integer for exact match
|
||||
error_code_int = int(error_code_search)
|
||||
query = query.where(ErrorLog.error_code == error_code_int)
|
||||
except ValueError:
|
||||
# If conversion fails, log a warning and potentially skip this filter
|
||||
logger.warning(f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter.")
|
||||
# Optionally, force count to 0 if the format is invalid:
|
||||
# return 0 # Or query = query.where(False) before fetching
|
||||
|
||||
count_result = await database.fetch_one(query)
|
||||
return count_result[0] if count_result else 0
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to count error logs with filters: {str(e)}") # Use exception for stack trace
|
||||
raise
|
||||
|
||||
|
||||
# 新增函数:获取单条错误日志详情
|
||||
async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据 ID 获取单个错误日志的详细信息
|
||||
|
||||
Args:
|
||||
log_id (int): 错误日志的 ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 包含日志详细信息的字典,如果未找到则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(ErrorLog).where(ErrorLog.id == log_id)
|
||||
result = await database.fetch_one(query)
|
||||
if result:
|
||||
# 将 request_msg (JSONB) 转换为字符串以便在 API 中返回
|
||||
log_dict = dict(result)
|
||||
if 'request_msg' in log_dict and log_dict['request_msg'] is not None:
|
||||
# 确保即使是 None 或非 JSON 数据也能处理
|
||||
try:
|
||||
log_dict['request_msg'] = json.dumps(log_dict['request_msg'], ensure_ascii=False, indent=2)
|
||||
except TypeError:
|
||||
log_dict['request_msg'] = str(log_dict['request_msg']) # Fallback to string
|
||||
return log_dict
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
# --- 异步删除函数 (使用 databases 库) ---
|
||||
|
||||
async def delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
||||
"""
|
||||
根据提供的 ID 列表批量删除错误日志 (异步)。
|
||||
|
||||
Args:
|
||||
log_ids: 要删除的错误日志 ID 列表。
|
||||
|
||||
Returns:
|
||||
int: 实际删除的日志数量。
|
||||
"""
|
||||
if not log_ids:
|
||||
return 0
|
||||
try:
|
||||
# 使用 databases 执行删除
|
||||
query = delete(ErrorLog).where(ErrorLog.id.in_(log_ids))
|
||||
# execute 返回受影响的行数,但 databases 库的 execute 不直接返回 rowcount
|
||||
# 我们需要先查询是否存在,或者依赖数据库约束/触发器(如果适用)
|
||||
# 或者,我们可以执行删除并假设成功,除非抛出异常
|
||||
# 为了简单起见,我们执行删除并记录日志,不精确返回删除数量
|
||||
# 如果需要精确数量,需要先执行 SELECT COUNT(*)
|
||||
await database.execute(query)
|
||||
# 注意:databases 的 execute 不返回 rowcount,所以我们不能直接返回删除的数量
|
||||
# 返回 log_ids 的长度作为尝试删除的数量,或者返回 0/1 表示操作尝试
|
||||
logger.info(f"Attempted bulk deletion for error logs with IDs: {log_ids}")
|
||||
return len(log_ids) # 返回尝试删除的数量
|
||||
except Exception as e:
|
||||
# 数据库连接或执行错误
|
||||
logger.error(f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True)
|
||||
raise # Re-raise the exception for the router to handle
|
||||
|
||||
async def delete_error_log_by_id(log_id: int) -> bool:
|
||||
"""
|
||||
根据 ID 删除单个错误日志 (异步)。
|
||||
|
||||
Args:
|
||||
log_id: 要删除的错误日志 ID。
|
||||
|
||||
Returns:
|
||||
bool: 如果成功删除返回 True,否则返回 False。
|
||||
"""
|
||||
try:
|
||||
# 先检查是否存在 (可选,但更明确)
|
||||
check_query = select(ErrorLog.id).where(ErrorLog.id == log_id)
|
||||
exists = await database.fetch_one(check_query)
|
||||
|
||||
if not exists:
|
||||
logger.warning(f"Attempted to delete non-existent error log with ID: {log_id}")
|
||||
return False # 或者可以抛出 404 异常,由路由处理
|
||||
|
||||
# 执行删除
|
||||
delete_query = delete(ErrorLog).where(ErrorLog.id == log_id)
|
||||
await database.execute(delete_query)
|
||||
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting error log with ID {log_id}: {e}", exc_info=True)
|
||||
raise # Re-raise the exception for the router to handle
|
||||
|
||||
# --- RequestLog Services (保持异步) ---
|
||||
|
||||
# 新增函数:添加请求日志
|
||||
async def add_request_log(
|
||||
model_name: Optional[str],
|
||||
api_key: Optional[str],
|
||||
is_success: bool,
|
||||
status_code: Optional[int] = None,
|
||||
latency_ms: Optional[int] = None,
|
||||
request_time: Optional[datetime] = None
|
||||
) -> bool:
|
||||
"""
|
||||
添加 API 请求日志
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
api_key: 使用的 API 密钥
|
||||
is_success: 请求是否成功
|
||||
status_code: API 响应状态码
|
||||
latency_ms: 请求耗时(毫秒)
|
||||
request_time: 请求发生时间 (如果为 None, 则使用当前时间)
|
||||
|
||||
Returns:
|
||||
bool: 是否添加成功
|
||||
"""
|
||||
try:
|
||||
log_time = request_time if request_time else datetime.now()
|
||||
|
||||
query = insert(RequestLog).values(
|
||||
request_time=log_time,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms
|
||||
)
|
||||
await database.execute(query)
|
||||
# logger.debug(f"Added request log: key={api_key[:4]}..., success={is_success}, model={model_name}") # Use debug level
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add request log: {str(e)}")
|
||||
return False
|
||||
78
app/domain/gemini_models.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
||||
|
||||
|
||||
class SafetySetting(BaseModel):
|
||||
category: Optional[
|
||||
Literal[
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_HARASSMENT",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
]
|
||||
] = None
|
||||
threshold: Optional[
|
||||
Literal[
|
||||
"HARM_BLOCK_THRESHOLD_UNSPECIFIED",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_NONE",
|
||||
"OFF",
|
||||
]
|
||||
] = None
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
stopSequences: Optional[List[str]] = None
|
||||
responseMimeType: Optional[str] = None
|
||||
responseSchema: Optional[Dict[str, Any]] = None
|
||||
candidateCount: Optional[int] = 1
|
||||
maxOutputTokens: Optional[int] = None
|
||||
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
||||
topP: Optional[float] = DEFAULT_TOP_P
|
||||
topK: Optional[int] = DEFAULT_TOP_K
|
||||
presencePenalty: Optional[float] = None
|
||||
frequencyPenalty: Optional[float] = None
|
||||
responseLogprobs: Optional[bool] = None
|
||||
logprobs: Optional[int] = None
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
|
||||
role: str = "system"
|
||||
parts: List[Dict[str, Any]] | Dict[str, Any]
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
role: str
|
||||
parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiRequest(BaseModel):
|
||||
contents: List[GeminiContent] = []
|
||||
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
||||
safetySettings: Optional[List[SafetySetting]] = Field(
|
||||
default=None, alias="safety_settings"
|
||||
)
|
||||
generationConfig: Optional[GenerationConfig] = Field(
|
||||
default=None, alias="generation_config"
|
||||
)
|
||||
systemInstruction: Optional[SystemInstruction] = Field(
|
||||
default=None, alias="system_instruction"
|
||||
)
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class ResetSelectedKeysRequest(BaseModel):
|
||||
keys: List[str]
|
||||
key_type: str
|
||||
|
||||
|
||||
class VerifySelectedKeysRequest(BaseModel):
|
||||
keys: List[str]
|
||||
35
app/domain/openai_models.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[dict]
|
||||
model: str = DEFAULT_MODEL
|
||||
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
||||
stream: Optional[bool] = False
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = DEFAULT_TOP_P
|
||||
top_k: Optional[int] = DEFAULT_TOP_K
|
||||
stop: Optional[Union[List[str],str]] = None
|
||||
reasoning_effort: Optional[str] = None
|
||||
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
||||
tool_choice: Optional[str] = None
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
input: Union[str, List[str]]
|
||||
model: str = "text-embedding-004"
|
||||
encoding_format: Optional[str] = "float"
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
model: str = "imagen-3.0-generate-002"
|
||||
prompt: str = ""
|
||||
n: int = 1
|
||||
size: Optional[str] = "1024x1024"
|
||||
quality: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
response_format: Optional[str] = "url"
|
||||
140
app/exception/exceptions.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.log.logger import get_exceptions_logger
|
||||
|
||||
logger = get_exceptions_logger()
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""API错误基类"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str, error_code: str = None):
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.error_code = error_code or "api_error"
|
||||
super().__init__(self.detail)
|
||||
|
||||
|
||||
class AuthenticationError(APIError):
|
||||
"""认证错误"""
|
||||
|
||||
def __init__(self, detail: str = "Authentication failed"):
|
||||
super().__init__(
|
||||
status_code=401, detail=detail, error_code="authentication_error"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationError(APIError):
|
||||
"""授权错误"""
|
||||
|
||||
def __init__(self, detail: str = "Not authorized to access this resource"):
|
||||
super().__init__(
|
||||
status_code=403, detail=detail, error_code="authorization_error"
|
||||
)
|
||||
|
||||
|
||||
class ResourceNotFoundError(APIError):
|
||||
"""资源未找到错误"""
|
||||
|
||||
def __init__(self, detail: str = "Resource not found"):
|
||||
super().__init__(
|
||||
status_code=404, detail=detail, error_code="resource_not_found"
|
||||
)
|
||||
|
||||
|
||||
class ModelNotSupportedError(APIError):
|
||||
"""模型不支持错误"""
|
||||
|
||||
def __init__(self, model: str):
|
||||
super().__init__(
|
||||
status_code=400,
|
||||
detail=f"Model {model} is not supported",
|
||||
error_code="model_not_supported",
|
||||
)
|
||||
|
||||
|
||||
class APIKeyError(APIError):
|
||||
"""API密钥错误"""
|
||||
|
||||
def __init__(self, detail: str = "Invalid or expired API key"):
|
||||
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
|
||||
|
||||
|
||||
class ServiceUnavailableError(APIError):
|
||||
"""服务不可用错误"""
|
||||
|
||||
def __init__(self, detail: str = "Service temporarily unavailable"):
|
||||
super().__init__(
|
||||
status_code=503, detail=detail, error_code="service_unavailable"
|
||||
)
|
||||
|
||||
|
||||
def setup_exception_handlers(app: FastAPI) -> None:
|
||||
"""
|
||||
设置应用程序的异常处理器
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
|
||||
@app.exception_handler(APIError)
|
||||
async def api_error_handler(request: Request, exc: APIError):
|
||||
"""处理API错误"""
|
||||
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": {"code": exc.error_code, "message": exc.detail}},
|
||||
)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
"""处理HTTP异常"""
|
||||
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": {"code": "http_error", "message": exc.detail}},
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
):
|
||||
"""处理请求验证错误"""
|
||||
error_details = []
|
||||
for error in exc.errors():
|
||||
error_details.append(
|
||||
{"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
|
||||
)
|
||||
|
||||
logger.error(f"Validation Error: {error_details}")
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={
|
||||
"error": {
|
||||
"code": "validation_error",
|
||||
"message": "Request validation failed",
|
||||
"details": error_details,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""处理通用异常"""
|
||||
logger.exception(f"Unhandled Exception: {str(exc)}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"code": "internal_server_error",
|
||||
"message": "An unexpected error occurred",
|
||||
}
|
||||
},
|
||||
)
|
||||
32
app/handler/error_handler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import HTTPException
|
||||
import logging
|
||||
|
||||
@asynccontextmanager
|
||||
async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None):
|
||||
"""
|
||||
一个异步上下文管理器,用于统一处理 FastAPI 路由中的常见错误和日志记录。
|
||||
|
||||
Args:
|
||||
logger: 用于记录日志的 Logger 实例。
|
||||
operation_name: 操作的名称,用于日志记录和错误详情。
|
||||
success_message: 操作成功时记录的自定义消息 (可选)。
|
||||
failure_message: 操作失败时记录的自定义消息 (可选)。
|
||||
"""
|
||||
default_success_msg = f"{operation_name} request successful"
|
||||
default_failure_msg = f"{operation_name} request failed"
|
||||
|
||||
logger.info("-" * 50 + operation_name + "-" * 50)
|
||||
try:
|
||||
yield
|
||||
logger.info(success_message or default_success_msg)
|
||||
except HTTPException as http_exc:
|
||||
# 如果已经是 HTTPException,直接重新抛出,保留原始状态码和详情
|
||||
logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})")
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
# 对于其他所有异常,记录错误并抛出标准的 500 错误
|
||||
logger.error(f"{failure_message or default_failure_msg}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Internal server error during {operation_name}"
|
||||
) from e
|
||||
359
app/handler/message_converter.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from app.core.constants import (
|
||||
AUDIO_FORMAT_TO_MIMETYPE,
|
||||
DATA_URL_PATTERN,
|
||||
IMAGE_URL_PATTERN,
|
||||
MAX_AUDIO_SIZE_BYTES,
|
||||
MAX_VIDEO_SIZE_BYTES,
|
||||
SUPPORTED_AUDIO_FORMATS,
|
||||
SUPPORTED_ROLES,
|
||||
SUPPORTED_VIDEO_FORMATS,
|
||||
VIDEO_FORMAT_TO_MIMETYPE,
|
||||
)
|
||||
from app.log.logger import get_message_converter_logger
|
||||
|
||||
logger = get_message_converter_logger()
|
||||
|
||||
|
||||
class MessageConverter(ABC):
|
||||
"""消息转换器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
|
||||
def _get_mime_type_and_data(base64_string):
|
||||
"""
|
||||
从 base64 字符串中提取 MIME 类型和数据。
|
||||
|
||||
参数:
|
||||
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
|
||||
|
||||
返回:
|
||||
tuple: (mime_type, encoded_data)
|
||||
"""
|
||||
# 检查字符串是否以 "data:" 格式开始
|
||||
if base64_string.startswith("data:"):
|
||||
# 提取 MIME 类型和数据
|
||||
pattern = DATA_URL_PATTERN
|
||||
match = re.match(pattern, base64_string)
|
||||
if match:
|
||||
mime_type = (
|
||||
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
)
|
||||
encoded_data = match.group(2)
|
||||
return mime_type, encoded_data
|
||||
|
||||
# 如果不是预期格式,假定它只是数据部分
|
||||
return None, base64_string
|
||||
|
||||
|
||||
def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||
if image_url.startswith("data:image"):
|
||||
mime_type, encoded_data = _get_mime_type_and_data(image_url)
|
||||
return {"inline_data": {"mime_type": mime_type, "data": encoded_data}}
|
||||
else:
|
||||
encoded_data = _convert_image_to_base64(image_url)
|
||||
return {"inline_data": {"mime_type": "image/png", "data": encoded_data}}
|
||||
|
||||
|
||||
def _convert_image_to_base64(url: str) -> str:
|
||||
"""
|
||||
将图片URL转换为base64编码
|
||||
Args:
|
||||
url: 图片URL
|
||||
Returns:
|
||||
str: base64编码的图片数据
|
||||
"""
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# 将图片内容转换为base64
|
||||
img_data = base64.b64encode(response.content).decode("utf-8")
|
||||
return img_data
|
||||
else:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
|
||||
|
||||
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理可能包含图片URL的文本,提取图片并转换为base64
|
||||
|
||||
Args:
|
||||
text: 可能包含图片URL的文本
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 包含文本和图片的部分列表
|
||||
"""
|
||||
parts = []
|
||||
img_url_match = re.search(IMAGE_URL_PATTERN, text)
|
||||
if img_url_match:
|
||||
# 提取URL
|
||||
img_url = img_url_match.group(2)
|
||||
# 将URL对应的图片转换为base64
|
||||
try:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append(
|
||||
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
||||
)
|
||||
except Exception:
|
||||
# 如果转换失败,回退到文本模式
|
||||
parts.append({"text": text})
|
||||
else:
|
||||
# 没有图片URL,作为纯文本处理
|
||||
parts.append({"text": text})
|
||||
return parts
|
||||
|
||||
|
||||
class OpenAIMessageConverter(MessageConverter):
|
||||
"""OpenAI消息格式转换器"""
|
||||
|
||||
def _validate_media_data(
|
||||
self, format: str, data: str, supported_formats: List[str], max_size: int
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Validates format and size of Base64 media data."""
|
||||
if format.lower() not in supported_formats:
|
||||
logger.error(
|
||||
f"Unsupported media format: {format}. Supported: {supported_formats}"
|
||||
)
|
||||
raise ValueError(f"Unsupported media format: {format}")
|
||||
|
||||
try:
|
||||
# Decode Base64 to check size
|
||||
# Be careful with memory usage for very large files
|
||||
# Consider streaming decoding or checking length heuristic first if memory is a concern
|
||||
decoded_data = base64.b64decode(
|
||||
data, validate=True
|
||||
) # Use validate=True for stricter check
|
||||
if len(decoded_data) > max_size:
|
||||
logger.error(
|
||||
f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
|
||||
)
|
||||
raise ValueError(
|
||||
f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
|
||||
)
|
||||
# No need to return decoded_data, just the original base64 if valid
|
||||
return data
|
||||
except base64.binascii.Error as e:
|
||||
logger.error(f"Invalid Base64 data provided: {e}")
|
||||
raise ValueError("Invalid Base64 data")
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating media data: {e}")
|
||||
raise
|
||||
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
system_instruction_parts = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
parts = []
|
||||
|
||||
if "content" in msg and isinstance(msg["content"], list):
|
||||
for content_item in msg["content"]:
|
||||
if not isinstance(content_item, dict):
|
||||
# Skip non-dict items if any unexpected format appears
|
||||
logger.warning(
|
||||
f"Skipping unexpected content item format: {type(content_item)}"
|
||||
)
|
||||
continue
|
||||
|
||||
content_type = content_item.get("type")
|
||||
|
||||
if content_type == "text" and content_item.get("text"):
|
||||
parts.append({"text": content_item["text"]})
|
||||
elif content_type == "image_url" and content_item.get(
|
||||
"image_url", {}
|
||||
).get("url"):
|
||||
try:
|
||||
parts.append(
|
||||
_convert_image(content_item["image_url"]["url"])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
|
||||
)
|
||||
# Decide how to handle: skip part, add error text, etc.
|
||||
parts.append(
|
||||
{
|
||||
"text": f"[Error processing image: {content_item['image_url']['url']}]"
|
||||
}
|
||||
)
|
||||
# --- Add handling for input_audio ---
|
||||
elif content_type == "input_audio" and content_item.get(
|
||||
"input_audio"
|
||||
):
|
||||
audio_info = content_item["input_audio"]
|
||||
audio_data = audio_info.get("data")
|
||||
audio_format = audio_info.get("format", "").lower()
|
||||
|
||||
if not audio_data or not audio_format:
|
||||
logger.warning(
|
||||
"Skipping audio part due to missing data or format."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Validate size and format
|
||||
validated_data = self._validate_media_data(
|
||||
audio_format,
|
||||
audio_data,
|
||||
SUPPORTED_AUDIO_FORMATS,
|
||||
MAX_AUDIO_SIZE_BYTES,
|
||||
)
|
||||
|
||||
# Get MIME type
|
||||
mime_type = AUDIO_FORMAT_TO_MIMETYPE.get(audio_format)
|
||||
if not mime_type:
|
||||
# Should not happen if format validation passed, but double-check
|
||||
logger.error(
|
||||
f"Could not find MIME type for supported format: {audio_format}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Internal error: MIME type mapping missing for {audio_format}"
|
||||
)
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": mime_type,
|
||||
"data": validated_data, # Use the validated Base64 data
|
||||
}
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Successfully added audio part (format: {audio_format})"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Skipping audio part due to validation error: {e}"
|
||||
)
|
||||
parts.append({"text": f"[Error processing audio: {e}]"})
|
||||
except Exception:
|
||||
logger.exception("Unexpected error processing audio part.")
|
||||
parts.append(
|
||||
{"text": "[Unexpected error processing audio]"}
|
||||
)
|
||||
|
||||
elif content_type == "input_video" and content_item.get(
|
||||
"input_video"
|
||||
):
|
||||
video_info = content_item["input_video"]
|
||||
video_data = video_info.get("data")
|
||||
video_format = video_info.get("format", "").lower()
|
||||
|
||||
if not video_data or not video_format:
|
||||
logger.warning(
|
||||
"Skipping video part due to missing data or format."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
validated_data = self._validate_media_data(
|
||||
video_format,
|
||||
video_data,
|
||||
SUPPORTED_VIDEO_FORMATS,
|
||||
MAX_VIDEO_SIZE_BYTES,
|
||||
)
|
||||
mime_type = VIDEO_FORMAT_TO_MIMETYPE.get(video_format)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Internal error: MIME type mapping missing for {video_format}"
|
||||
)
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": mime_type,
|
||||
"data": validated_data,
|
||||
}
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Successfully added video part (format: {video_format})"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Skipping video part due to validation error: {e}"
|
||||
)
|
||||
parts.append({"text": f"[Error processing video: {e}]"})
|
||||
except Exception:
|
||||
logger.exception("Unexpected error processing video part.")
|
||||
parts.append(
|
||||
{"text": "[Unexpected error processing video]"}
|
||||
)
|
||||
|
||||
else:
|
||||
# Log unrecognized but present types
|
||||
if content_type:
|
||||
logger.warning(
|
||||
f"Unsupported content type or missing data in structured content: {content_type}"
|
||||
)
|
||||
|
||||
elif (
|
||||
"content" in msg and isinstance(msg["content"], str) and msg["content"]
|
||||
):
|
||||
parts.extend(_process_text_with_image(msg["content"]))
|
||||
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||
# Keep existing tool call processing
|
||||
for tool_call in msg["tool_calls"]:
|
||||
function_call = tool_call.get("function", {})
|
||||
# Sanitize arguments loading
|
||||
arguments_str = function_call.get("arguments", "{}")
|
||||
try:
|
||||
function_call["args"] = json.loads(arguments_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to decode tool call arguments: {arguments_str}"
|
||||
)
|
||||
function_call["args"] = {}
|
||||
if "arguments" in function_call:
|
||||
if "arguments" in function_call:
|
||||
del function_call["arguments"]
|
||||
|
||||
parts.append({"functionCall": function_call})
|
||||
|
||||
if role not in SUPPORTED_ROLES:
|
||||
if role == "tool":
|
||||
role = "user"
|
||||
else:
|
||||
# 如果是最后一条消息,则认为是用户消息
|
||||
if idx == len(messages) - 1:
|
||||
role = "user"
|
||||
else:
|
||||
role = "model"
|
||||
if parts:
|
||||
if role == "system":
|
||||
text_only_parts = [p for p in parts if "text" in p]
|
||||
if len(text_only_parts) != len(parts):
|
||||
logger.warning(
|
||||
"Non-text parts found in system message; discarding them."
|
||||
)
|
||||
if text_only_parts:
|
||||
system_instruction_parts.extend(text_only_parts)
|
||||
|
||||
else:
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
system_instruction = (
|
||||
None
|
||||
if not system_instruction_parts
|
||||
else {
|
||||
"role": "system",
|
||||
"parts": system_instruction_parts,
|
||||
}
|
||||
)
|
||||
return converted_messages, system_instruction
|
||||
@@ -1,22 +1,23 @@
|
||||
# app/services/chat/response_handler.py
|
||||
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional
|
||||
import time
|
||||
import uuid
|
||||
from app.core.config import settings
|
||||
from app.core.uploader import ImageUploaderFactory
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
|
||||
class ResponseHandler(ABC):
|
||||
"""响应处理器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -27,32 +28,44 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
self.thinking_first = True
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
if stream:
|
||||
return _handle_gemini_stream_response(response, model, stream)
|
||||
return _handle_gemini_normal_response(response, model, stream)
|
||||
|
||||
|
||||
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
|
||||
def _handle_openai_stream_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=True, gemini_format=False
|
||||
)
|
||||
if not text and not tool_calls:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {"content": text, "role": "assistant"}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
|
||||
return {
|
||||
template_chunk = {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
|
||||
}
|
||||
if usage_metadata:
|
||||
template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}
|
||||
return template_chunk
|
||||
|
||||
|
||||
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
|
||||
def _handle_openai_normal_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=False, gemini_format=False
|
||||
)
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
@@ -61,11 +74,15 @@ def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
"usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)},
|
||||
}
|
||||
|
||||
|
||||
@@ -78,59 +95,68 @@ class OpenAIResponseHandler(ResponseHandler):
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
usage_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if stream:
|
||||
return _handle_openai_stream_response(response, model, finish_reason)
|
||||
return _handle_openai_normal_response(response, model, finish_reason)
|
||||
|
||||
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
|
||||
return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
|
||||
return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)
|
||||
|
||||
def handle_image_chat_response(
|
||||
self, image_str: str, model: str, stream=False, finish_reason="stop"
|
||||
):
|
||||
if stream:
|
||||
return _handle_openai_stream_image_response(image_str,model,finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str,model,finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
return _handle_openai_stream_image_response(image_str, model, finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str, model, finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(
|
||||
image_str: str, model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason
|
||||
}]
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
def _handle_openai_normal_image_response(
|
||||
image_str: str, model: str, finish_reason: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": image_str
|
||||
},
|
||||
"finish_reason": finish_reason
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": image_str},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
|
||||
def _extract_result(
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
gemini_format: bool = False,
|
||||
) -> tuple[str, List[Dict[str, Any]]]:
|
||||
text, tool_calls = "", []
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
@@ -146,13 +172,9 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = _format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
text = _format_execution_result(parts[0]["executableCodeResult"])
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
text = _format_execution_result(parts[0]["codeExecutionResult"])
|
||||
elif "inlineData" in parts[0]:
|
||||
text = _extract_image_data(parts[0])
|
||||
else:
|
||||
@@ -166,10 +188,10 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = (
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
)
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
@@ -187,34 +209,47 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
|
||||
tool_calls = _extract_tool_calls(
|
||||
candidate["content"]["parts"], gemini_format
|
||||
)
|
||||
else:
|
||||
text = "暂无返回"
|
||||
return text, tool_calls
|
||||
|
||||
|
||||
def _extract_image_data(part: dict) -> str:
|
||||
image_uploader = None
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
base64_data = part["inlineData"]["data"]
|
||||
#将base64_data转成bytes数组
|
||||
# 将base64_data转成bytes数组
|
||||
bytes_data = base64.b64decode(base64_data)
|
||||
upload_response = image_uploader.upload(bytes_data,filename)
|
||||
upload_response = image_uploader.upload(bytes_data, filename)
|
||||
if upload_response.success:
|
||||
text = f""
|
||||
text = f"\n\n\n\n"
|
||||
else:
|
||||
text = ""
|
||||
return text
|
||||
|
||||
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
def _extract_tool_calls(
|
||||
parts: List[Dict[str, Any]], gemini_format: bool
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""提取工具调用信息"""
|
||||
if not parts or not isinstance(parts, list):
|
||||
return []
|
||||
@@ -250,8 +285,12 @@ def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> Lis
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
def _handle_gemini_stream_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
@@ -260,8 +299,12 @@ def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream:
|
||||
return response
|
||||
|
||||
|
||||
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
def _handle_gemini_normal_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
@@ -279,10 +322,10 @@ def _format_code_block(code_data: dict) -> str:
|
||||
|
||||
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
50
app/handler/retry_handler.py
Normal file
@@ -0,0 +1,50 @@
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_retry_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = get_retry_logger()
|
||||
|
||||
|
||||
class RetryHandler:
|
||||
"""重试处理装饰器"""
|
||||
|
||||
def __init__(self, key_arg: str = "api_key"):
|
||||
self.key_arg = key_arg
|
||||
|
||||
def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> T:
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(settings.MAX_RETRIES):
|
||||
retries = attempt + 1
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(
|
||||
f"API call failed with error: {str(e)}. Attempt {retries} of {settings.MAX_RETRIES}"
|
||||
)
|
||||
|
||||
# 从函数参数中获取 key_manager
|
||||
key_manager = kwargs.get("key_manager")
|
||||
if key_manager:
|
||||
old_key = kwargs.get(self.key_arg)
|
||||
new_key = await key_manager.handle_api_failure(old_key, retries)
|
||||
if new_key:
|
||||
kwargs[self.key_arg] = new_key
|
||||
logger.info(f"Switched to new API key: {new_key}")
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
|
||||
logger.error(
|
||||
f"All retry attempts failed, raising final exception: {str(last_exception)}"
|
||||
)
|
||||
raise last_exception
|
||||
|
||||
return wrapper
|
||||
@@ -1,10 +1,17 @@
|
||||
# app/services/chat/stream_optimizer.py
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
from typing import Any, List, AsyncGenerator, Callable
|
||||
from app.core.logger import get_openai_logger, get_gemini_logger
|
||||
from app.core.config import settings
|
||||
from typing import Any, AsyncGenerator, Callable, List
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import (
|
||||
DEFAULT_STREAM_CHUNK_SIZE,
|
||||
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||
DEFAULT_STREAM_MAX_DELAY,
|
||||
DEFAULT_STREAM_MIN_DELAY,
|
||||
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||
)
|
||||
from app.log.logger import get_gemini_logger, get_openai_logger
|
||||
|
||||
logger_openai = get_openai_logger()
|
||||
logger_gemini = get_gemini_logger()
|
||||
@@ -12,19 +19,21 @@ logger_gemini = get_gemini_logger()
|
||||
|
||||
class StreamOptimizer:
|
||||
"""流式输出优化器
|
||||
|
||||
|
||||
提供流式输出优化功能,包括智能延迟调整和长文本分块输出。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logger=None,
|
||||
min_delay: float = 0.016,
|
||||
max_delay: float = 0.024,
|
||||
short_text_threshold: int = 10,
|
||||
long_text_threshold: int = 50,
|
||||
chunk_size: int = 5):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger=None,
|
||||
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
|
||||
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
|
||||
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
|
||||
):
|
||||
"""初始化流式输出优化器
|
||||
|
||||
|
||||
参数:
|
||||
logger: 日志记录器
|
||||
min_delay: 最小延迟时间(秒)
|
||||
@@ -39,13 +48,13 @@ class StreamOptimizer:
|
||||
self.short_text_threshold = short_text_threshold
|
||||
self.long_text_threshold = long_text_threshold
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
|
||||
def calculate_delay(self, text_length: int) -> float:
|
||||
"""根据文本长度计算延迟时间
|
||||
|
||||
|
||||
参数:
|
||||
text_length: 文本长度
|
||||
|
||||
|
||||
返回:
|
||||
延迟时间(秒)
|
||||
"""
|
||||
@@ -58,48 +67,50 @@ class StreamOptimizer:
|
||||
else:
|
||||
# 中等长度文本使用线性插值计算延迟
|
||||
# 使用对数函数使延迟变化更平滑
|
||||
ratio = math.log(text_length / self.short_text_threshold) / math.log(self.long_text_threshold / self.short_text_threshold)
|
||||
ratio = math.log(text_length / self.short_text_threshold) / math.log(
|
||||
self.long_text_threshold / self.short_text_threshold
|
||||
)
|
||||
return self.max_delay - ratio * (self.max_delay - self.min_delay)
|
||||
|
||||
|
||||
def split_text_into_chunks(self, text: str) -> List[str]:
|
||||
"""将文本分割成小块
|
||||
|
||||
|
||||
参数:
|
||||
text: 要分割的文本
|
||||
|
||||
|
||||
返回:
|
||||
文本块列表
|
||||
"""
|
||||
return [text[i:i+self.chunk_size] for i in range(0, len(text), self.chunk_size)]
|
||||
|
||||
async def optimize_stream_output(self,
|
||||
text: str,
|
||||
create_response_chunk: Callable[[str], Any],
|
||||
format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]:
|
||||
return [
|
||||
text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
|
||||
]
|
||||
|
||||
async def optimize_stream_output(
|
||||
self,
|
||||
text: str,
|
||||
create_response_chunk: Callable[[str], Any],
|
||||
format_chunk: Callable[[Any], str],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""优化流式输出
|
||||
|
||||
|
||||
参数:
|
||||
text: 要输出的文本
|
||||
create_response_chunk: 创建响应块的函数,接收文本,返回响应块
|
||||
format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串
|
||||
|
||||
|
||||
返回:
|
||||
异步生成器,生成格式化后的响应块
|
||||
"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
|
||||
# 计算智能延迟时间
|
||||
delay = self.calculate_delay(len(text))
|
||||
if self.logger:
|
||||
self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
|
||||
|
||||
|
||||
# 根据文本长度决定输出方式
|
||||
if len(text) >= self.long_text_threshold:
|
||||
# 长文本:分块输出
|
||||
chunks = self.split_text_into_chunks(text)
|
||||
if self.logger:
|
||||
self.logger.info(f"Long text: splitting into {len(chunks)} chunks")
|
||||
for chunk_text in chunks:
|
||||
chunk_response = create_response_chunk(chunk_text)
|
||||
yield format_chunk(chunk_response)
|
||||
@@ -119,7 +130,7 @@ openai_optimizer = StreamOptimizer(
|
||||
max_delay=settings.STREAM_MAX_DELAY,
|
||||
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
||||
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE,
|
||||
)
|
||||
|
||||
gemini_optimizer = StreamOptimizer(
|
||||
@@ -128,5 +139,5 @@ gemini_optimizer = StreamOptimizer(
|
||||
max_delay=settings.STREAM_MAX_DELAY,
|
||||
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
||||
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE,
|
||||
)
|
||||
225
app/log/logger.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import logging
|
||||
import platform
|
||||
import sys
|
||||
from typing import Dict, Optional
|
||||
|
||||
# ANSI转义序列颜色代码
|
||||
COLORS = {
|
||||
"DEBUG": "\033[34m", # 蓝色
|
||||
"INFO": "\033[32m", # 绿色
|
||||
"WARNING": "\033[33m", # 黄色
|
||||
"ERROR": "\033[31m", # 红色
|
||||
"CRITICAL": "\033[1;31m", # 红色加粗
|
||||
}
|
||||
|
||||
# Windows系统启用ANSI支持
|
||||
if platform.system() == "Windows":
|
||||
import ctypes
|
||||
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""
|
||||
自定义的日志格式化器,添加颜色支持
|
||||
"""
|
||||
|
||||
def format(self, record):
|
||||
# 获取对应级别的颜色代码
|
||||
color = COLORS.get(record.levelname, "")
|
||||
# 添加颜色代码和重置代码
|
||||
record.levelname = f"{color}{record.levelname}\033[0m"
|
||||
# 创建包含文件名和行号的固定宽度字符串
|
||||
record.fileloc = f"[{record.filename}:{record.lineno}]"
|
||||
return super().format(record)
|
||||
|
||||
|
||||
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
|
||||
FORMATTER = ColoredFormatter(
|
||||
"%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
|
||||
)
|
||||
|
||||
# 日志级别映射
|
||||
LOG_LEVELS = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
@staticmethod
|
||||
def setup_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
设置并获取logger
|
||||
:param name: logger名称
|
||||
:return: logger实例
|
||||
"""
|
||||
# 导入 settings 对象
|
||||
from app.config.config import settings
|
||||
|
||||
# 从全局配置获取日志级别
|
||||
log_level_str = settings.LOG_LEVEL.lower()
|
||||
level = LOG_LEVELS.get(log_level_str, logging.INFO)
|
||||
|
||||
if name in Logger._loggers:
|
||||
# 如果 logger 已存在,检查并更新其级别(如果需要)
|
||||
existing_logger = Logger._loggers[name]
|
||||
if existing_logger.level != level:
|
||||
existing_logger.setLevel(level)
|
||||
return existing_logger
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
logger.propagate = False
|
||||
|
||||
# 添加控制台输出
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(FORMATTER)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
Logger._loggers[name] = logger
|
||||
return logger
|
||||
|
||||
@staticmethod
|
||||
def get_logger(name: str) -> Optional[logging.Logger]:
|
||||
"""
|
||||
获取已存在的logger
|
||||
:param name: logger名称
|
||||
:return: logger实例或None
|
||||
"""
|
||||
return Logger._loggers.get(name)
|
||||
|
||||
@staticmethod
|
||||
def update_log_levels(log_level: str):
|
||||
"""
|
||||
根据当前的全局配置更新所有已创建 logger 的日志级别。
|
||||
"""
|
||||
log_level_str = log_level.lower()
|
||||
new_level = LOG_LEVELS.get(log_level_str, logging.INFO)
|
||||
|
||||
updated_count = 0
|
||||
for logger_name, logger_instance in Logger._loggers.items():
|
||||
if logger_instance.level != new_level:
|
||||
logger_instance.setLevel(new_level)
|
||||
# 可选:记录级别变更日志,但注意避免在日志模块内部产生过多日志
|
||||
# print(f"Updated log level for logger '{logger_name}' to {log_level_str.upper()}")
|
||||
updated_count += 1
|
||||
|
||||
|
||||
# 预定义的loggers
|
||||
def get_openai_logger():
|
||||
return Logger.setup_logger("openai")
|
||||
|
||||
|
||||
def get_gemini_logger():
|
||||
return Logger.setup_logger("gemini")
|
||||
|
||||
|
||||
def get_chat_logger():
|
||||
return Logger.setup_logger("chat")
|
||||
|
||||
|
||||
def get_model_logger():
|
||||
return Logger.setup_logger("model")
|
||||
|
||||
|
||||
def get_security_logger():
|
||||
return Logger.setup_logger("security")
|
||||
|
||||
|
||||
def get_key_manager_logger():
|
||||
return Logger.setup_logger("key_manager")
|
||||
|
||||
|
||||
def get_main_logger():
|
||||
return Logger.setup_logger("main")
|
||||
|
||||
|
||||
def get_embeddings_logger():
|
||||
return Logger.setup_logger("embeddings")
|
||||
|
||||
|
||||
def get_request_logger():
|
||||
return Logger.setup_logger("request")
|
||||
|
||||
|
||||
def get_retry_logger():
|
||||
return Logger.setup_logger("retry")
|
||||
|
||||
|
||||
def get_image_create_logger():
|
||||
return Logger.setup_logger("image_create")
|
||||
|
||||
|
||||
def get_exceptions_logger():
|
||||
return Logger.setup_logger("exceptions")
|
||||
|
||||
|
||||
def get_application_logger():
|
||||
return Logger.setup_logger("application")
|
||||
|
||||
|
||||
def get_initialization_logger():
|
||||
return Logger.setup_logger("initialization")
|
||||
|
||||
|
||||
def get_middleware_logger():
|
||||
return Logger.setup_logger("middleware")
|
||||
|
||||
|
||||
def get_routes_logger():
|
||||
return Logger.setup_logger("routes")
|
||||
|
||||
|
||||
def get_config_routes_logger():
|
||||
return Logger.setup_logger("config_routes")
|
||||
|
||||
|
||||
def get_config_logger():
|
||||
return Logger.setup_logger("config")
|
||||
|
||||
|
||||
def get_database_logger():
|
||||
return Logger.setup_logger("database")
|
||||
|
||||
|
||||
def get_log_routes_logger():
|
||||
return Logger.setup_logger("log_routes")
|
||||
|
||||
|
||||
def get_stats_logger():
|
||||
return Logger.setup_logger("stats")
|
||||
|
||||
|
||||
def get_update_logger():
|
||||
return Logger.setup_logger("update_service")
|
||||
|
||||
|
||||
def get_scheduler_routes():
|
||||
return Logger.setup_logger("scheduler_routes")
|
||||
|
||||
|
||||
def get_message_converter_logger():
|
||||
return Logger.setup_logger("message_converter")
|
||||
|
||||
|
||||
def get_api_client_logger():
|
||||
return Logger.setup_logger("api_client")
|
||||
|
||||
|
||||
def get_openai_compatible_logger():
|
||||
return Logger.setup_logger("openai_compatible")
|
||||
|
||||
|
||||
def get_error_log_logger():
|
||||
return Logger.setup_logger("error_log")
|
||||
|
||||
131
app/main.py
@@ -1,134 +1,11 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from app.core.logger import get_main_logger
|
||||
from app.core.security import verify_auth_token
|
||||
from app.services.key_manager import get_key_manager_instance
|
||||
from app.core.config import settings
|
||||
|
||||
from app.api import gemini_routes, openai_routes
|
||||
import uvicorn
|
||||
|
||||
from app.core.application import create_app
|
||||
from app.log.logger import get_main_logger
|
||||
|
||||
# 配置日志
|
||||
logger = get_main_logger()
|
||||
app = create_app()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 配置Jinja2模板
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
# 配置静态文件
|
||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||
|
||||
# 创建 KeyManager 实例
|
||||
key_manager = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global key_manager
|
||||
logger.info("Application starting up...")
|
||||
try:
|
||||
key_manager = await get_key_manager_instance(settings.API_KEYS)
|
||||
logger.info("KeyManager initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize KeyManager: {str(e)}")
|
||||
raise
|
||||
|
||||
# 添加中间件来处理未经身份验证的请求
|
||||
@app.middleware("http")
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
# 允许 gemini_routes 和 openai_routes 中的端点绕过身份验证
|
||||
if (request.url.path not in ["/", "/auth"] and
|
||||
not request.url.path.startswith("/static") and
|
||||
not request.url.path.startswith("/gemini") and
|
||||
not request.url.path.startswith("/v1") and
|
||||
not request.url.path.startswith("/v1beta") and
|
||||
not request.url.path.startswith("/health") and
|
||||
not request.url.path.startswith("/hf")):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
||||
return RedirectResponse(url="/")
|
||||
logger.debug("Request authenticated successfully")
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# 添加请求日志中间件
|
||||
# app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
# 配置CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境建议配置具体的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
|
||||
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
||||
expose_headers=["*"], # 允许前端访问的响应头
|
||||
max_age=600, # 预检请求缓存时间(秒)
|
||||
)
|
||||
|
||||
# 包含所有路由
|
||||
app.include_router(openai_routes.router)
|
||||
app.include_router(gemini_routes.router)
|
||||
app.include_router(gemini_routes.router_v1beta)
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def auth_page(request: Request):
|
||||
return templates.TemplateResponse("auth.html", {"request": request})
|
||||
|
||||
|
||||
@app.post("/auth")
|
||||
async def authenticate(request: Request):
|
||||
try:
|
||||
form = await request.form()
|
||||
auth_token = form.get("auth_token")
|
||||
if not auth_token:
|
||||
logger.warning("Authentication attempt with empty token")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
if verify_auth_token(auth_token):
|
||||
logger.info("Successful authentication")
|
||||
response = RedirectResponse(url="/keys", status_code=302)
|
||||
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
|
||||
return response
|
||||
logger.warning("Failed authentication attempt with invalid token")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {str(e)}")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
@app.get("/keys", response_class=HTMLResponse)
|
||||
async def keys_page(request: Request):
|
||||
try:
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to keys page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
|
||||
return templates.TemplateResponse("keys_status.html", {
|
||||
"request": request,
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
"total": total
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving keys status: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check(request: Request):
|
||||
logger.info("Health check endpoint called")
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = get_main_logger()
|
||||
logger.info("Starting application server...")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
||||
75
app/middleware/middleware.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
中间件配置模块,负责设置和配置应用程序的中间件
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_middleware_logger
|
||||
|
||||
logger = get_middleware_logger()
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
认证中间件,处理未经身份验证的请求
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# 允许特定路径绕过身份验证
|
||||
if (
|
||||
request.url.path not in ["/", "/auth"]
|
||||
and not request.url.path.startswith("/static")
|
||||
and not request.url.path.startswith("/gemini")
|
||||
and not request.url.path.startswith("/v1")
|
||||
and not request.url.path.startswith(f"/{API_VERSION}")
|
||||
and not request.url.path.startswith("/health")
|
||||
and not request.url.path.startswith("/hf")
|
||||
and not request.url.path.startswith("/openai")
|
||||
and not request.url.path.startswith("/api/version/check")
|
||||
):
|
||||
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
||||
return RedirectResponse(url="/")
|
||||
logger.debug("Request authenticated successfully")
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
def setup_middlewares(app: FastAPI) -> None:
|
||||
"""
|
||||
设置应用程序的中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
# 添加认证中间件
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
# 添加请求日志中间件(可选,默认注释掉)
|
||||
# app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
# 配置CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境建议配置具体的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=[
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
], # 明确指定允许的HTTP方法
|
||||
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
||||
expose_headers=["*"], # 允许前端访问的响应头
|
||||
max_age=600, # 预检请求缓存时间(秒)
|
||||
)
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import json
|
||||
from app.core.logger import get_request_logger
|
||||
|
||||
from app.log.logger import get_request_logger
|
||||
|
||||
logger = get_request_logger()
|
||||
|
||||
@@ -20,9 +22,11 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
# 尝试格式化JSON
|
||||
try:
|
||||
formatted_body = json.loads(body_str)
|
||||
logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
|
||||
logger.info(
|
||||
f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.info("Request body is not valid JSON.")
|
||||
logger.error("Request body is not valid JSON.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading request body: {str(e)}")
|
||||
|
||||
|
||||
142
app/router/config_routes.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
配置路由模块
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import Logger, get_config_routes_logger
|
||||
from app.service.config.config_service import ConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
logger = get_config_routes_logger()
|
||||
|
||||
|
||||
@router.get("", response_model=Dict[str, Any])
|
||||
async def get_config(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to config page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
return await ConfigService.get_config()
|
||||
|
||||
|
||||
@router.put("", response_model=Dict[str, Any])
|
||||
async def update_config(config_data: Dict[str, Any], request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to config page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
try:
|
||||
result = await ConfigService.update_config(config_data)
|
||||
# 配置更新成功后,立即更新所有 logger 的级别
|
||||
Logger.update_log_levels(config_data["LOG_LEVEL"])
|
||||
logger.info("Log levels updated after configuration change.")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating config or log levels: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/reset", response_model=Dict[str, Any])
|
||||
async def reset_config(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to config page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
try:
|
||||
return await ConfigService.reset_config()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# Pydantic model for bulk delete request
|
||||
class DeleteKeysRequest(BaseModel):
|
||||
keys: List[str] = Field(..., description="List of API keys to delete")
|
||||
|
||||
|
||||
@router.delete("/keys/{key_to_delete}", response_model=Dict[str, Any])
|
||||
async def delete_single_key(key_to_delete: str, request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized attempt to delete key: {key_to_delete}")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
try:
|
||||
logger.info(f"Attempting to delete key: {key_to_delete}")
|
||||
result = await ConfigService.delete_key(key_to_delete)
|
||||
if not result.get("success"):
|
||||
# Optionally, translate specific errors to HTTP status codes
|
||||
# For now, let's assume 400 for any failure from service if not found,
|
||||
# or 500 if it was an unexpected error (though service should handle that)
|
||||
raise HTTPException(
|
||||
status_code=(
|
||||
404 if "not found" in result.get("message", "").lower() else 400
|
||||
),
|
||||
detail=result.get("message"),
|
||||
)
|
||||
return result
|
||||
except HTTPException as e:
|
||||
# Re-raise HTTPExceptions directly
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting key '{key_to_delete}': {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting key: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/keys/delete-selected", response_model=Dict[str, Any])
|
||||
async def delete_selected_keys_route(
|
||||
delete_request: DeleteKeysRequest, request: Request
|
||||
):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized attempt to bulk delete keys")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
if not delete_request.keys:
|
||||
logger.warning("Attempt to bulk delete keys with an empty list.")
|
||||
raise HTTPException(status_code=400, detail="No keys provided for deletion.")
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to bulk delete {len(delete_request.keys)} keys.")
|
||||
result = await ConfigService.delete_selected_keys(delete_request.keys)
|
||||
# Similar to single delete, we can check result["success"]
|
||||
if not result.get("success") and result.get("deleted_count", 0) == 0:
|
||||
# If no keys were actually deleted, it might be a client error (e.g., all keys not found)
|
||||
# or an empty list was somehow passed despite the check above.
|
||||
raise HTTPException(
|
||||
status_code=400, detail=result.get("message", "Failed to delete keys.")
|
||||
)
|
||||
# If some keys were deleted but others not found, it's still a partial success, return 200 with details.
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error bulk deleting keys: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error bulk deleting keys: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/ui/models")
|
||||
async def get_ui_models(request: Request):
|
||||
auth_token_cookie = request.cookies.get("auth_token")
|
||||
if not auth_token_cookie or not verify_auth_token(auth_token_cookie):
|
||||
logger.warning("Unauthorized access attempt to /api/config/ui/models")
|
||||
raise HTTPException(status_code=403, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
models = await ConfigService.fetch_ui_models()
|
||||
return models
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in /ui/models endpoint: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching UI models: {str(e)}",
|
||||
)
|
||||
211
app/router/error_log_routes.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
日志路由模块
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
Response,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_log_routes_logger
|
||||
from app.service.error_log import error_log_service
|
||||
|
||||
router = APIRouter(prefix="/api/logs", tags=["logs"])
|
||||
|
||||
logger = get_log_routes_logger()
|
||||
|
||||
|
||||
class ErrorLogListItem(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_code: Optional[int] = None
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
|
||||
class ErrorLogListResponse(BaseModel):
|
||||
logs: List[ErrorLogListItem]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/errors", response_model=ErrorLogListResponse)
|
||||
async def get_error_logs_api(
|
||||
request: Request,
|
||||
limit: int = Query(10, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
key_search: Optional[str] = Query(
|
||||
None, description="Search term for Gemini key (partial match)"
|
||||
),
|
||||
error_search: Optional[str] = Query(
|
||||
None, description="Search term for error type or log message"
|
||||
),
|
||||
error_code_search: Optional[str] = Query(
|
||||
None, description="Search term for error code"
|
||||
),
|
||||
start_date: Optional[datetime] = Query(
|
||||
None, description="Start datetime for filtering"
|
||||
),
|
||||
end_date: Optional[datetime] = Query(
|
||||
None, description="End datetime for filtering"
|
||||
),
|
||||
sort_by: str = Query(
|
||||
"id", description="Field to sort by (e.g., 'id', 'request_time')"
|
||||
),
|
||||
sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"),
|
||||
):
|
||||
"""
|
||||
获取错误日志列表 (返回错误码),支持过滤和排序
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
limit: 限制数量
|
||||
offset: 偏移量
|
||||
key_search: 密钥搜索
|
||||
error_search: 错误搜索 (可能搜索类型或日志内容,由DB层决定)
|
||||
error_code_search: 错误码搜索
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
sort_by: 排序字段
|
||||
sort_order: 排序顺序
|
||||
|
||||
Returns:
|
||||
ErrorLogListResponse: An object containing the list of logs (with error_code) and the total count.
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to error logs list")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
result = await error_log_service.process_get_error_logs(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
logs_data = result["logs"]
|
||||
total_count = result["total"]
|
||||
|
||||
validated_logs = [ErrorLogListItem(**log) for log in logs_data]
|
||||
return ErrorLogListResponse(logs=validated_logs, total=total_count)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error logs list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get error logs list: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class ErrorLogDetailResponse(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_log: Optional[str] = None
|
||||
request_msg: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
|
||||
@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse)
|
||||
async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=1)):
|
||||
"""
|
||||
根据日志 ID 获取错误日志的详细信息 (包括 error_log 和 request_msg)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(
|
||||
f"Unauthorized access attempt to error log details for ID: {log_id}"
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
log_details = await error_log_service.process_get_error_log_details(
|
||||
log_id=log_id
|
||||
)
|
||||
if not log_details:
|
||||
raise HTTPException(status_code=404, detail="Error log not found")
|
||||
|
||||
return ErrorLogDetailResponse(**log_details)
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get error log details: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_logs_bulk_api(
|
||||
request: Request, payload: Dict[str, List[int]] = Body(...)
|
||||
):
|
||||
"""
|
||||
批量删除错误日志 (异步)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to bulk delete error logs")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
log_ids = payload.get("ids")
|
||||
if not log_ids:
|
||||
raise HTTPException(status_code=400, detail="No log IDs provided for deletion.")
|
||||
|
||||
try:
|
||||
deleted_count = await error_log_service.process_delete_error_logs_by_ids(
|
||||
log_ids
|
||||
)
|
||||
# 注意:异步函数返回的是尝试删除的数量,可能不是精确值
|
||||
logger.info(
|
||||
f"Attempted bulk deletion for {deleted_count} error logs with IDs: {log_ids}"
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error bulk deleting error logs with IDs {log_ids}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error during bulk deletion"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)):
|
||||
"""
|
||||
删除单个错误日志 (异步)
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized access attempt to delete error log ID: {log_id}")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
success = await error_log_service.process_delete_error_log_by_id(log_id)
|
||||
if not success:
|
||||
# 服务层现在在未找到时返回 False,我们在这里转换为 404
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Error log with ID {log_id} not found"
|
||||
)
|
||||
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(f"Error deleting error log with ID {log_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error during deletion"
|
||||
)
|
||||
374
app/router/gemini_routes.py
Normal file
@@ -0,0 +1,374 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
|
||||
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
||||
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
||||
logger = get_gemini_logger()
|
||||
|
||||
security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
"""获取密钥管理器实例"""
|
||||
return await get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取下一个可用的API密钥"""
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
|
||||
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取Gemini聊天服务实例"""
|
||||
return GeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
@router_v1beta.get("/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
logger.info("-" * 50 + operation_name + "-" * 50)
|
||||
logger.info("Handling Gemini models list request")
|
||||
|
||||
try:
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
models_json["models"].append(item)
|
||||
|
||||
if settings.SEARCH_MODELS:
|
||||
for name in settings.SEARCH_MODELS:
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
|
||||
logger.info("Gemini models list request successful")
|
||||
return models_json
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
|
||||
@router_v1beta.post("/models/{model_name}:generateContent")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
):
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
|
||||
logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
|
||||
logger.info(f"Received reset request with key_type: {key_type}")
|
||||
|
||||
try:
|
||||
# 获取分类后的密钥
|
||||
keys_by_status = await key_manager.get_keys_by_status()
|
||||
valid_keys = keys_by_status.get("valid_keys", {})
|
||||
invalid_keys = keys_by_status.get("invalid_keys", {})
|
||||
|
||||
# 根据类型选择要重置的密钥
|
||||
keys_to_reset = []
|
||||
if key_type == "valid":
|
||||
keys_to_reset = list(valid_keys.keys())
|
||||
logger.info(f"Resetting only valid keys, count: {len(keys_to_reset)}")
|
||||
elif key_type == "invalid":
|
||||
keys_to_reset = list(invalid_keys.keys())
|
||||
logger.info(f"Resetting only invalid keys, count: {len(keys_to_reset)}")
|
||||
else:
|
||||
# 重置所有密钥
|
||||
await key_manager.reset_failure_counts()
|
||||
return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})
|
||||
|
||||
# 批量重置指定类型的密钥
|
||||
for key in keys_to_reset:
|
||||
await key_manager.reset_key_failure_count(key)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"{key_type}密钥的失败计数已重置",
|
||||
"reset_count": len(keys_to_reset)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset key failure counts: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
|
||||
|
||||
|
||||
@router.post("/reset-selected-fail-counts")
|
||||
async def reset_selected_key_fail_counts(
|
||||
request: ResetSelectedKeysRequest,
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""批量重置选定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
|
||||
keys_to_reset = request.keys
|
||||
key_type = request.key_type
|
||||
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
|
||||
|
||||
if not keys_to_reset:
|
||||
return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)
|
||||
|
||||
reset_count = 0
|
||||
errors = []
|
||||
|
||||
try:
|
||||
for key in keys_to_reset:
|
||||
try:
|
||||
result = await key_manager.reset_key_failure_count(key)
|
||||
if result:
|
||||
reset_count += 1
|
||||
else:
|
||||
logger.warning(f"Key not found during selective reset: {key}")
|
||||
except Exception as key_error:
|
||||
logger.error(f"Error resetting key {key}: {str(key_error)}")
|
||||
errors.append(f"Key {key}: {str(key_error)}")
|
||||
|
||||
if errors:
|
||||
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
||||
final_success = reset_count > 0
|
||||
status_code = 207 if final_success and errors else 500
|
||||
return JSONResponse({
|
||||
"success": final_success,
|
||||
"message": error_message,
|
||||
"reset_count": reset_count
|
||||
}, status_code=status_code)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
||||
"reset_count": reset_count
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
|
||||
|
||||
|
||||
@router.post("/reset-fail-count/{api_key}")
|
||||
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""重置指定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
|
||||
logger.info(f"Resetting failure count for API key: {api_key}")
|
||||
|
||||
try:
|
||||
result = await key_manager.reset_key_failure_count(api_key)
|
||||
if result:
|
||||
return JSONResponse({"success": True, "message": "失败计数已重置"})
|
||||
return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset key failure count: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
|
||||
|
||||
|
||||
@router.post("/verify-key/{api_key}")
|
||||
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""验证Gemini API密钥的有效性"""
|
||||
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
||||
logger.info("Verifying API key validity")
|
||||
|
||||
try:
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=[{"text": "hi"}],
|
||||
)
|
||||
],
|
||||
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
|
||||
)
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
)
|
||||
|
||||
if response:
|
||||
return JSONResponse({"status": "valid"})
|
||||
except Exception as e:
|
||||
logger.error(f"Key verification failed: {str(e)}")
|
||||
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Verification exception for key: {api_key}, incrementing failure count")
|
||||
|
||||
return JSONResponse({"status": "invalid", "error": str(e)})
|
||||
|
||||
|
||||
@router.post("/verify-selected-keys")
|
||||
async def verify_selected_keys(
|
||||
request: VerifySelectedKeysRequest,
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""批量验证选定Gemini API密钥的有效性"""
|
||||
logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
|
||||
keys_to_verify = request.keys
|
||||
logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")
|
||||
|
||||
if not keys_to_verify:
|
||||
return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)
|
||||
|
||||
successful_keys = []
|
||||
failed_keys = {}
|
||||
|
||||
async def _verify_single_key(api_key: str):
|
||||
"""内部函数,用于验证单个密钥并处理异常"""
|
||||
nonlocal successful_keys, failed_keys
|
||||
try:
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
|
||||
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
|
||||
)
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
)
|
||||
successful_keys.append(api_key)
|
||||
return api_key, "valid", None
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.warning(f"Key verification failed for {api_key}: {error_message}")
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
|
||||
else:
|
||||
key_manager.key_failure_counts[api_key] = 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
|
||||
failed_keys[api_key] = error_message
|
||||
return api_key, "invalid", error_message
|
||||
|
||||
tasks = [_verify_single_key(key) for key in keys_to_verify]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"An unexpected error occurred during bulk verification task: {result}")
|
||||
elif result:
|
||||
if not isinstance(result, Exception) and result:
|
||||
key, status, error = result
|
||||
elif isinstance(result, Exception):
|
||||
logger.error(f"Task execution error during bulk verification: {result}")
|
||||
|
||||
valid_count = len(successful_keys)
|
||||
invalid_count = len(failed_keys)
|
||||
logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")
|
||||
|
||||
if failed_keys:
|
||||
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": failed_keys,
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count
|
||||
})
|
||||
else:
|
||||
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": {},
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": 0
|
||||
})
|
||||
113
app/router/openai_compatiable_routes.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.openai_models import (
|
||||
ChatRequest,
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
security_service = SecurityService()
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_next_working_key_wrapper(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
|
||||
async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取OpenAI聊天服务实例"""
|
||||
return OpenAICompatiableService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/openai/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""获取可用模型列表。"""
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await openai_service.get_models(api_key)
|
||||
|
||||
|
||||
@router.post("/openai/v1/chat/completions")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key()
|
||||
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
|
||||
if is_image_chat:
|
||||
response = await openai_service.create_image_chat_completion(request, current_api_key)
|
||||
return response
|
||||
else:
|
||||
response = await openai_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/openai/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
request.model = settings.CREATE_IMAGE_MODEL
|
||||
return await openai_service.generate_images(request)
|
||||
|
||||
|
||||
@router.post("/openai/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理文本嵌入请求。"""
|
||||
operation_name = "embedding"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await openai_service.create_embeddings(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
149
app/router/openai_routes.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.openai_models import (
|
||||
ChatRequest,
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_logger()
|
||||
|
||||
security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
embedding_service = EmbeddingService()
|
||||
image_create_service = ImageCreateService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_next_working_key_wrapper(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
|
||||
async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取OpenAI聊天服务实例"""
|
||||
return OpenAIChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
@router.get("/hf/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。"""
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return await model_service.get_gemini_openai_models(api_key)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@router.post("/hf/v1/chat/completions")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
|
||||
):
|
||||
"""处理 OpenAI 聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key()
|
||||
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
|
||||
if not await model_service.check_model_support(request.model):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
|
||||
if is_image_chat:
|
||||
response = await chat_service.create_image_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
else:
|
||||
response = await chat_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@router.post("/hf/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
):
|
||||
"""处理 OpenAI 图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
response = image_create_service.generate_images(request)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@router.post("/hf/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""处理 OpenAI 文本嵌入请求。"""
|
||||
operation_name = "embedding"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
response = await embedding_service.create_embedding(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/v1/keys/list")
|
||||
@router.get("/hf/v1/keys/list")
|
||||
async def get_keys_list(
|
||||
_=Depends(security_service.verify_auth_token),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取有效和无效的API key列表 (需要管理 Token 认证)。"""
|
||||
operation_name = "get_keys_list"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling keys list request")
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
return {
|
||||
"status": "success",
|
||||
"data": {
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
},
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
|
||||
}
|
||||
186
app/router/routes.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
路由配置模块,负责设置和配置应用程序的路由
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_routes_logger
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.stats.stats_service import StatsService
|
||||
|
||||
logger = get_routes_logger()
|
||||
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
def setup_routers(app: FastAPI) -> None:
|
||||
"""
|
||||
设置应用程序的路由
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
app.include_router(openai_routes.router)
|
||||
app.include_router(gemini_routes.router)
|
||||
app.include_router(gemini_routes.router_v1beta)
|
||||
app.include_router(config_routes.router)
|
||||
app.include_router(error_log_routes.router)
|
||||
app.include_router(scheduler_routes.router)
|
||||
app.include_router(stats_routes.router)
|
||||
app.include_router(version_routes.router)
|
||||
app.include_router(openai_compatiable_routes.router)
|
||||
|
||||
setup_page_routes(app)
|
||||
|
||||
setup_health_routes(app)
|
||||
setup_api_stats_routes(app)
|
||||
|
||||
|
||||
def setup_page_routes(app: FastAPI) -> None:
|
||||
"""
|
||||
设置页面相关的路由
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def auth_page(request: Request):
|
||||
"""认证页面"""
|
||||
return templates.TemplateResponse("auth.html", {"request": request})
|
||||
|
||||
@app.post("/auth")
|
||||
async def authenticate(request: Request):
|
||||
"""处理认证请求"""
|
||||
try:
|
||||
form = await request.form()
|
||||
auth_token = form.get("auth_token")
|
||||
if not auth_token:
|
||||
logger.warning("Authentication attempt with empty token")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
if verify_auth_token(auth_token):
|
||||
logger.info("Successful authentication")
|
||||
response = RedirectResponse(url="/config", status_code=302)
|
||||
response.set_cookie(
|
||||
key="auth_token", value=auth_token, httponly=True, max_age=3600
|
||||
)
|
||||
return response
|
||||
logger.warning("Failed authentication attempt with invalid token")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {str(e)}")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
@app.get("/keys", response_class=HTMLResponse)
|
||||
async def keys_page(request: Request):
|
||||
"""密钥管理页面"""
|
||||
try:
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to keys page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
key_manager = await get_key_manager_instance()
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
total_keys = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||
valid_key_count = len(keys_status["valid_keys"])
|
||||
invalid_key_count = len(keys_status["invalid_keys"])
|
||||
|
||||
stats_service = StatsService()
|
||||
api_stats = await stats_service.get_api_usage_stats()
|
||||
logger.info(f"API stats retrieved: {api_stats}")
|
||||
|
||||
logger.info(f"Keys status retrieved successfully. Total keys: {total_keys}")
|
||||
return templates.TemplateResponse(
|
||||
"keys_status.html",
|
||||
{
|
||||
"request": request,
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
"total_keys": total_keys,
|
||||
"valid_key_count": valid_key_count,
|
||||
"invalid_key_count": invalid_key_count,
|
||||
"api_stats": api_stats,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving keys status or API stats: {str(e)}")
|
||||
raise
|
||||
|
||||
@app.get("/config", response_class=HTMLResponse)
|
||||
async def config_page(request: Request):
|
||||
"""配置编辑页面"""
|
||||
try:
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to config page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
logger.info("Config page accessed successfully")
|
||||
return templates.TemplateResponse("config_editor.html", {"request": request})
|
||||
except Exception as e:
|
||||
logger.error(f"Error accessing config page: {str(e)}")
|
||||
raise
|
||||
|
||||
@app.get("/logs", response_class=HTMLResponse)
|
||||
async def logs_page(request: Request):
|
||||
"""错误日志页面"""
|
||||
try:
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to logs page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
logger.info("Logs page accessed successfully")
|
||||
return templates.TemplateResponse("error_logs.html", {"request": request})
|
||||
except Exception as e:
|
||||
logger.error(f"Error accessing logs page: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def setup_health_routes(app: FastAPI) -> None:
|
||||
"""
|
||||
设置健康检查相关的路由
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check(request: Request):
|
||||
"""健康检查端点"""
|
||||
logger.info("Health check endpoint called")
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
def setup_api_stats_routes(app: FastAPI) -> None:
|
||||
"""
|
||||
设置 API 统计相关的路由
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
@app.get("/api/stats/details")
|
||||
async def api_stats_details(request: Request, period: str):
|
||||
"""获取指定时间段内的 API 调用详情"""
|
||||
try:
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to API stats details")
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
logger.info(f"Fetching API call details for period: {period}")
|
||||
stats_service = StatsService()
|
||||
details = await stats_service.get_api_call_details(period)
|
||||
return details
|
||||
except ValueError as e:
|
||||
logger.warning(f"Invalid period requested for API stats details: {period} - {str(e)}")
|
||||
return {"error": str(e)}, 400
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching API stats details for period {period}: {str(e)}")
|
||||
return {"error": "Internal server error"}, 500
|
||||
57
app/router/scheduler_routes.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
定时任务控制路由模块
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
||||
from app.log.logger import get_scheduler_routes
|
||||
|
||||
logger = get_scheduler_routes()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/scheduler",
|
||||
tags=["Scheduler"]
|
||||
)
|
||||
|
||||
async def verify_token(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to scheduler API")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
@router.post("/start", summary="启动定时任务")
|
||||
async def start_scheduler_endpoint(request: Request):
|
||||
"""Start the background scheduler task"""
|
||||
await verify_token(request)
|
||||
try:
|
||||
logger.info("Received request to start scheduler.")
|
||||
start_scheduler()
|
||||
return JSONResponse(content={"message": "Scheduler started successfully."}, status_code=status.HTTP_200_OK)
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting scheduler: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to start scheduler: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/stop", summary="停止定时任务")
|
||||
async def stop_scheduler_endpoint(request: Request):
|
||||
"""Stop the background scheduler task"""
|
||||
await verify_token(request)
|
||||
try:
|
||||
logger.info("Received request to stop scheduler.")
|
||||
stop_scheduler()
|
||||
return JSONResponse(content={"message": "Scheduler stopped successfully."}, status_code=status.HTTP_200_OK)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping scheduler: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to stop scheduler: {str(e)}"
|
||||
)
|
||||
55
app/router/stats_routes.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from starlette import status
|
||||
from app.core.security import verify_auth_token
|
||||
from app.service.stats.stats_service import StatsService
|
||||
from app.log.logger import get_stats_logger
|
||||
|
||||
logger = get_stats_logger()
|
||||
|
||||
|
||||
async def verify_token(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to scheduler API")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api",
|
||||
tags=["stats"],
|
||||
dependencies=[Depends(verify_token)]
|
||||
)
|
||||
|
||||
stats_service = StatsService()
|
||||
|
||||
@router.get("/key-usage-details/{key}",
|
||||
summary="获取指定密钥最近24小时的模型调用次数",
|
||||
description="根据提供的 API 密钥,返回过去24小时内每个模型被调用的次数统计。")
|
||||
async def get_key_usage_details(key: str):
|
||||
"""
|
||||
Retrieves the model usage count for a specific API key within the last 24 hours.
|
||||
|
||||
Args:
|
||||
key: The API key to get usage details for.
|
||||
|
||||
Returns:
|
||||
A dictionary with model names as keys and their call counts as values.
|
||||
Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during data retrieval.
|
||||
"""
|
||||
try:
|
||||
usage_details = await stats_service.get_key_usage_details_last_24h(key)
|
||||
if usage_details is None:
|
||||
return {}
|
||||
return usage_details
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching key usage details for key {key[:4]}...: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取密钥使用详情时出错: {e}"
|
||||
)
|
||||
37
app/router/version_routes.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
from app.service.update.update_service import check_for_updates
|
||||
from app.utils.helpers import get_current_version
|
||||
from app.log.logger import get_update_logger
|
||||
|
||||
router = APIRouter(prefix="/api/version", tags=["Version"])
|
||||
logger = get_update_logger()
|
||||
|
||||
class VersionInfo(BaseModel):
|
||||
current_version: str = Field(..., description="当前应用程序版本")
|
||||
latest_version: Optional[str] = Field(None, description="可用的最新版本")
|
||||
update_available: bool = Field(False, description="是否有可用更新")
|
||||
error_message: Optional[str] = Field(None, description="检查更新时发生的错误信息")
|
||||
|
||||
@router.get("/check", response_model=VersionInfo, summary="检查应用程序更新")
|
||||
async def get_version_info():
|
||||
"""
|
||||
检查当前应用程序版本与最新的 GitHub release 版本。
|
||||
"""
|
||||
try:
|
||||
current_version = get_current_version()
|
||||
update_available, latest_version, error_message = await check_for_updates()
|
||||
|
||||
logger.info(f"Version check API result: current={current_version}, latest={latest_version}, available={update_available}, error='{error_message}'")
|
||||
|
||||
return VersionInfo(
|
||||
current_version=current_version,
|
||||
latest_version=latest_version,
|
||||
update_available=update_available,
|
||||
error_message=error_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in /api/version/check endpoint: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="检查版本信息时发生内部错误")
|
||||
162
app/scheduler/scheduled_tasks.py
Normal file
@@ -0,0 +1,162 @@
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest
|
||||
from app.log.logger import Logger
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.error_log.error_log_service import delete_old_error_logs
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
||||
|
||||
logger = Logger.setup_logger("scheduler")
|
||||
|
||||
|
||||
async def check_failed_keys():
|
||||
"""
|
||||
定时检查失败次数大于0的API密钥,并尝试验证它们。
|
||||
如果验证成功,重置失败计数;如果失败,增加失败计数。
|
||||
"""
|
||||
logger.info("Starting scheduled check for failed API keys...")
|
||||
try:
|
||||
key_manager = await get_key_manager_instance()
|
||||
# 确保 KeyManager 已经初始化
|
||||
if not key_manager or not hasattr(key_manager, "key_failure_counts"):
|
||||
logger.warning(
|
||||
"KeyManager instance not available or not initialized. Skipping check."
|
||||
)
|
||||
return
|
||||
|
||||
# 创建 GeminiChatService 实例用于验证
|
||||
# 注意:这里直接创建实例,而不是通过依赖注入,因为这是后台任务
|
||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
# 获取需要检查的 key 列表 (失败次数 > 0)
|
||||
keys_to_check = []
|
||||
async with key_manager.failure_count_lock: # 访问共享数据需要加锁
|
||||
# 复制一份以避免在迭代时修改字典
|
||||
failure_counts_copy = key_manager.key_failure_counts.copy()
|
||||
keys_to_check = [
|
||||
key for key, count in failure_counts_copy.items() if count > 0
|
||||
] # 检查所有失败次数大于0的key
|
||||
|
||||
if not keys_to_check:
|
||||
logger.info("No keys with failure count > 0 found. Skipping verification.")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Found {len(keys_to_check)} keys with failure count > 0 to verify."
|
||||
)
|
||||
|
||||
for key in keys_to_check:
|
||||
# 隐藏部分 key 用于日志记录
|
||||
log_key = f"{key[:4]}...{key[-4:]}" if len(key) > 8 else key
|
||||
logger.info(f"Verifying key: {log_key}...")
|
||||
try:
|
||||
# 构造测试请求
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=[{"text": "hi"}], # 使用简单的文本进行验证
|
||||
)
|
||||
]
|
||||
)
|
||||
# 调用 generate_content 进行验证
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL, gemini_request, key # 使用配置中定义的测试模型
|
||||
)
|
||||
# 如果没有抛出异常,说明 key 有效
|
||||
logger.info(
|
||||
f"Key {log_key} verification successful. Resetting failure count."
|
||||
)
|
||||
await key_manager.reset_key_failure_count(key)
|
||||
except Exception as e:
|
||||
# 验证失败,增加失败计数
|
||||
logger.warning(
|
||||
f"Key {log_key} verification failed: {str(e)}. Incrementing failure count."
|
||||
)
|
||||
# 直接操作计数器,需要加锁
|
||||
async with key_manager.failure_count_lock:
|
||||
# 再次检查 key 是否存在且失败次数未达上限
|
||||
if (
|
||||
key in key_manager.key_failure_counts
|
||||
and key_manager.key_failure_counts[key]
|
||||
< key_manager.MAX_FAILURES
|
||||
):
|
||||
key_manager.key_failure_counts[key] += 1
|
||||
logger.info(
|
||||
f"Failure count for key {log_key} incremented to {key_manager.key_failure_counts[key]}."
|
||||
)
|
||||
elif key in key_manager.key_failure_counts:
|
||||
logger.warning(
|
||||
f"Key {log_key} reached MAX_FAILURES ({key_manager.MAX_FAILURES}). Not incrementing further."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled key check: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
def setup_scheduler():
|
||||
"""设置并启动 APScheduler"""
|
||||
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
||||
# 添加检查失败密钥的定时任务
|
||||
scheduler.add_job(
|
||||
check_failed_keys,
|
||||
"interval",
|
||||
hours=settings.CHECK_INTERVAL_HOURS,
|
||||
id="check_failed_keys_job",
|
||||
name="Check Failed API Keys",
|
||||
)
|
||||
logger.info(
|
||||
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
|
||||
)
|
||||
|
||||
# 新增:添加自动删除错误日志的定时任务,每天凌晨3点执行
|
||||
scheduler.add_job(
|
||||
delete_old_error_logs,
|
||||
"cron",
|
||||
hour=3,
|
||||
minute=0,
|
||||
id="delete_old_error_logs_job",
|
||||
name="Delete Old Error Logs",
|
||||
)
|
||||
logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")
|
||||
|
||||
# 新增:添加自动删除请求日志的定时任务,每天凌晨3点05分执行
|
||||
scheduler.add_job(
|
||||
delete_old_request_logs_task,
|
||||
"cron",
|
||||
hour=3,
|
||||
minute=5,
|
||||
id="delete_old_request_logs_job",
|
||||
name="Delete Old Request Logs",
|
||||
)
|
||||
logger.info(
|
||||
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduler started with all jobs.")
|
||||
return scheduler
|
||||
|
||||
|
||||
# 可以在这里添加一个全局的 scheduler 实例,以便在应用关闭时优雅地停止
|
||||
scheduler_instance = None
|
||||
|
||||
|
||||
def start_scheduler():
|
||||
global scheduler_instance
|
||||
if scheduler_instance is None or not scheduler_instance.running:
|
||||
logger.info("Starting scheduler...")
|
||||
scheduler_instance = setup_scheduler()
|
||||
logger.info("Scheduler is already running.")
|
||||
|
||||
|
||||
def stop_scheduler():
|
||||
global scheduler_instance
|
||||
if scheduler_instance and scheduler_instance.running:
|
||||
scheduler_instance.shutdown()
|
||||
logger.info("Scheduler stopped.")
|
||||
@@ -1,40 +0,0 @@
|
||||
from typing import List, Optional, Dict, Any, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SafetySetting(BaseModel):
|
||||
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
|
||||
threshold: Optional[Literal["HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
stopSequences: Optional[List[str]] = None
|
||||
responseMimeType: Optional[str] = None
|
||||
responseSchema: Optional[Dict[str, Any]] = None
|
||||
candidateCount: Optional[int] = 1
|
||||
maxOutputTokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
topK: Optional[int] = None
|
||||
presencePenalty: Optional[float] = None
|
||||
frequencyPenalty: Optional[float] = None
|
||||
responseLogprobs: Optional[bool] = None
|
||||
logprobs: Optional[int] = None
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
|
||||
role: str = "system"
|
||||
parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
role: str
|
||||
parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiRequest(BaseModel):
|
||||
contents: List[GeminiContent] = []
|
||||
tools: Optional[List[Dict[str, Any]]] = []
|
||||
safetySettings: Optional[List[SafetySetting]] = None
|
||||
generationConfig: Optional[GenerationConfig] = {}
|
||||
systemInstruction: Optional[SystemInstruction] = None
|
||||
@@ -1,30 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[dict]
|
||||
model: str = "gemini-1.5-flash-002"
|
||||
temperature: Optional[float] = 0.7
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[List[dict]] = []
|
||||
max_tokens: Optional[int] = 8192
|
||||
stop: Optional[List[str]] = []
|
||||
top_p: Optional[float] = 0.9
|
||||
top_k: Optional[int] = 40
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
input: Union[str, List[str]]
|
||||
model: str = "text-embedding-004"
|
||||
encoding_format: Optional[str] = "float"
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
model: str = "DALL-E-3"
|
||||
prompt: str = ""
|
||||
n: int = 1
|
||||
size: Optional[str] = "1024x1024"
|
||||
quality: Optional[str] = ""
|
||||
style: Optional[str] = ""
|
||||
response_format: Optional[str] = "url"
|
||||
284
app/service/chat/gemini_chat_service.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import json
|
||||
import re
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片部分"""
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
if "image_url" in part or "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
|
||||
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
record = dict()
|
||||
for item in tools:
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
for k, v in item.items():
|
||||
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||
functions = record.get("functionDeclarations", [])
|
||||
functions.extend(v)
|
||||
record["functionDeclarations"] = functions
|
||||
else:
|
||||
record[k] = v
|
||||
return record
|
||||
|
||||
tool = dict()
|
||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||
if payload.get("tools") and isinstance(payload.get("tools"), dict):
|
||||
payload["tools"] = [payload.get("tools")]
|
||||
items = payload.get("tools", [])
|
||||
if items and isinstance(items, list):
|
||||
tool.update(_merge_tools(items))
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(payload.get("contents", []))
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
payload = {
|
||||
"contents": request_dict.get("contents", []),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class GeminiChatService:
|
||||
"""聊天服务"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager):
|
||||
self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
|
||||
self.key_manager = key_manager
|
||||
self.response_handler = GeminiResponseHandler()
|
||||
|
||||
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
|
||||
"""从响应中提取文本内容"""
|
||||
if not response.get("candidates"):
|
||||
return ""
|
||||
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts and "text" in parts[0]:
|
||||
return parts[0].get("text", "")
|
||||
return ""
|
||||
|
||||
def _create_char_response(
|
||||
self, original_response: Dict[str, Any], text: str
|
||||
) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的响应"""
|
||||
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
|
||||
if response_copy.get("candidates") and response_copy["candidates"][0].get(
|
||||
"content", {}
|
||||
).get("parts"):
|
||||
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
|
||||
return response_copy
|
||||
|
||||
async def generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
payload = _build_payload(model, request)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now() # Record request time
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 on success
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
# Try to parse status code from exception
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default to 500 if parsing fails
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="gemini-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
)
|
||||
raise e # Re-throw exception for upstream handling
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
# Log request to request log table
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
|
||||
async def stream_generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
payload = _build_payload(model, request)
|
||||
is_success = False
|
||||
status_code = None
|
||||
final_api_key = api_key
|
||||
|
||||
while retries < max_retries:
|
||||
request_datetime = datetime.datetime.now()
|
||||
start_time = time.perf_counter()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key # Update final key used
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, current_attempt_key
|
||||
):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
line = line[6:]
|
||||
response_data = self.response_handler.handle_response(
|
||||
json.loads(line), model, stream=True
|
||||
)
|
||||
text = self._extract_text_from_response(response_data)
|
||||
# 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理
|
||||
if text and settings.STREAM_OPTIMIZER_ENABLED:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for (
|
||||
optimized_chunk
|
||||
) in gemini_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_response(response_data, t),
|
||||
lambda c: "data: " + json.dumps(c) + "\n\n",
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
# 如果没有文本内容(如工具调用等),整块输出
|
||||
yield "data: " + json.dumps(response_data) + "\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
is_success = True
|
||||
status_code = 200
|
||||
break
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
# Parse error code for logging
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key, # Log key used for this failed attempt
|
||||
model_name=model,
|
||||
error_type="gemini-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
)
|
||||
|
||||
# Attempt to switch API Key
|
||||
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else: # No more keys or retries exceeded by handle_api_failure logic
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break # Exit loop if no key available
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming."
|
||||
)
|
||||
break # Exit loop after max retries
|
||||
finally:
|
||||
# Log the final outcome of the streaming request
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key, # Log the last key used
|
||||
is_success=is_success, # Log the final success status
|
||||
status_code=status_code, # Log the last known status code
|
||||
latency_ms=latency_ms, # Log total time including retries
|
||||
request_time=request_datetime
|
||||
)
|
||||
605
app/service/chat/openai_chat_service.py
Normal file
@@ -0,0 +1,605 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.database.services import (
|
||||
add_error_log,
|
||||
add_request_log,
|
||||
)
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.handler.message_converter import OpenAIMessageConverter
|
||||
from app.handler.response_handler import OpenAIResponseHandler
|
||||
from app.handler.stream_optimizer import openai_optimizer
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.key.key_manager import KeyManager
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片、音频或视频部分 (inline_data)"""
|
||||
for content in contents:
|
||||
if content and "parts" in content and isinstance(content["parts"], list):
|
||||
for part in content["parts"]:
|
||||
if isinstance(part, dict) and "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_tools(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
tool = dict()
|
||||
model = request.model
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (
|
||||
model.endswith("-search")
|
||||
or "-thinking" in model
|
||||
or model.endswith("-image")
|
||||
or model.endswith("-image-generation")
|
||||
)
|
||||
and not _has_media_parts(messages) # Use the updated check
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
logger.debug("Code execution tool enabled.")
|
||||
elif _has_media_parts(messages):
|
||||
logger.debug("Code execution tool disabled due to media parts presence.")
|
||||
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
# 将 request 中的 tools 合并到 tools 中
|
||||
if request.tools:
|
||||
function_declarations = []
|
||||
for item in request.tools:
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
if item.get("type", "") == "function" and item.get("function"):
|
||||
function = deepcopy(item.get("function"))
|
||||
parameters = function.get("parameters", {})
|
||||
if parameters.get("type") == "object" and not parameters.get(
|
||||
"properties", {}
|
||||
):
|
||||
function.pop("parameters", None)
|
||||
|
||||
function_declarations.append(function)
|
||||
|
||||
if function_declarations:
|
||||
# 按照 function 的 name 去重
|
||||
names, functions = set(), []
|
||||
for fc in function_declarations:
|
||||
if fc.get("name") not in names:
|
||||
names.add(fc.get("name"))
|
||||
functions.append(fc)
|
||||
|
||||
tool["functionDeclarations"] = functions
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
# if (
|
||||
# "2.0" in model
|
||||
# and "gemini-2.0-flash-thinking-exp" not in model
|
||||
# and "gemini-2.0-pro-exp" not in model
|
||||
# ):
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest,
|
||||
messages: List[Dict[str, Any]],
|
||||
instruction: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
payload = {
|
||||
"contents": messages,
|
||||
"generationConfig": {
|
||||
"temperature": request.temperature,
|
||||
"stopSequences": request.stop,
|
||||
"topP": request.top_p,
|
||||
"topK": request.top_k,
|
||||
},
|
||||
"tools": _build_tools(request, messages),
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
if request.max_tokens is not None:
|
||||
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
if request.model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if request.model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
|
||||
}
|
||||
|
||||
if (
|
||||
instruction
|
||||
and isinstance(instruction, dict)
|
||||
and instruction.get("role") == "system"
|
||||
and instruction.get("parts")
|
||||
and not request.model.endswith("-image")
|
||||
and not request.model.endswith("-image-generation")
|
||||
):
|
||||
payload["systemInstruction"] = instruction
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class OpenAIChatService:
|
||||
"""聊天服务"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
self.message_converter = OpenAIMessageConverter()
|
||||
self.response_handler = OpenAIResponseHandler(config=None)
|
||||
self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
|
||||
self.key_manager = key_manager
|
||||
self.image_create_service = ImageCreateService()
|
||||
|
||||
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
|
||||
"""从OpenAI响应块中提取文本内容"""
|
||||
if not chunk.get("choices"):
|
||||
return ""
|
||||
|
||||
choice = chunk["choices"][0]
|
||||
if "delta" in choice and "content" in choice["delta"]:
|
||||
return choice["delta"]["content"]
|
||||
return ""
|
||||
|
||||
def _create_char_openai_chunk(
|
||||
self, original_chunk: Dict[str, Any], text: str
|
||||
) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的OpenAI响应块"""
|
||||
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
|
||||
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
|
||||
chunk_copy["choices"][0]["delta"]["content"] = text
|
||||
return chunk_copy
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
# 转换消息格式
|
||||
messages, instruction = self.message_converter.convert(request.messages)
|
||||
|
||||
# 构建请求payload
|
||||
payload = _build_payload(request, messages, instruction)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, payload, api_key)
|
||||
return await self._handle_normal_completion(request.model, payload, api_key)
|
||||
|
||||
async def _handle_normal_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理普通聊天完成"""
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
usage_metadata = response.get("usageMetadata", {})
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
# Try to parse status code from exception
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _fake_stream_logic_impl(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理伪流式 (fake stream) 的核心逻辑"""
|
||||
logger.info(
|
||||
f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint."
|
||||
)
|
||||
keep_sending_empty_data = True
|
||||
|
||||
async def send_empty_data_locally() -> AsyncGenerator[str, None]:
|
||||
"""定期发送空数据以保持连接"""
|
||||
while keep_sending_empty_data:
|
||||
await asyncio.sleep(settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS)
|
||||
if keep_sending_empty_data:
|
||||
empty_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
||||
yield f"data: {json.dumps(empty_chunk)}\n\n"
|
||||
logger.debug("Sent empty data chunk for fake stream heartbeat.")
|
||||
|
||||
empty_data_generator = send_empty_data_locally()
|
||||
api_response_task = asyncio.create_task(
|
||||
self.api_client.generate_content(payload, model, api_key)
|
||||
)
|
||||
|
||||
try:
|
||||
while not api_response_task.done():
|
||||
try:
|
||||
next_empty_chunk = await asyncio.wait_for(
|
||||
empty_data_generator.__anext__(), timeout=0.1
|
||||
)
|
||||
yield next_empty_chunk
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except (
|
||||
StopAsyncIteration
|
||||
):
|
||||
break
|
||||
|
||||
response = await api_response_task
|
||||
finally:
|
||||
keep_sending_empty_data = False
|
||||
|
||||
if response and response.get("candidates"):
|
||||
response = self.response_handler.handle_response(response, model, stream=True, finish_reason='stop', usage_metadata=response.get("usageMetadata", {}))
|
||||
yield f"data: {json.dumps(response)}\n\n"
|
||||
logger.info(f"Sent full response content for fake stream: {model}")
|
||||
else:
|
||||
error_message = "Failed to get response from model"
|
||||
if (
|
||||
response and isinstance(response, dict) and response.get("error")
|
||||
):
|
||||
error_details = response.get("error")
|
||||
if isinstance(error_details, dict):
|
||||
error_message = error_details.get("message", error_message)
|
||||
|
||||
logger.error(
|
||||
f"No candidates or error in response for fake stream model {model}: {response}"
|
||||
)
|
||||
error_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
||||
yield f"data: {json.dumps(error_chunk)}\n\n"
|
||||
|
||||
async def _real_stream_logic_impl(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理真实流式 (real stream) 的核心逻辑"""
|
||||
tool_call_flag = False
|
||||
usage_metadata = None
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, api_key
|
||||
):
|
||||
if line.startswith("data:"):
|
||||
chunk_str = line[6:]
|
||||
if not chunk_str or chunk_str.isspace():
|
||||
logger.debug(
|
||||
f"Received empty data line for model {model}, skipping."
|
||||
)
|
||||
continue
|
||||
try:
|
||||
chunk = json.loads(chunk_str)
|
||||
usage_metadata = chunk.get("usageMetadata", {})
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Failed to decode JSON from stream for model {model}: {chunk_str}"
|
||||
)
|
||||
continue
|
||||
openai_chunk = self.response_handler.handle_response(
|
||||
chunk, model, stream=True, finish_reason=None, usage_metadata=usage_metadata
|
||||
)
|
||||
if openai_chunk:
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text and settings.STREAM_OPTIMIZER_ENABLED:
|
||||
async for (
|
||||
optimized_chunk_data
|
||||
) in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||
):
|
||||
yield optimized_chunk_data
|
||||
else:
|
||||
if openai_chunk.get("choices") and openai_chunk["choices"][0].get("delta", {}).get("tool_calls"):
|
||||
tool_call_flag = True
|
||||
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
|
||||
if tool_call_flag:
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls', usage_metadata=usage_metadata))}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=usage_metadata))}\n\n"
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑和假流式支持"""
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
is_success = False
|
||||
status_code = None
|
||||
final_api_key = api_key
|
||||
|
||||
while retries < max_retries:
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
current_attempt_key = final_api_key
|
||||
|
||||
try:
|
||||
stream_generator = None
|
||||
if settings.FAKE_STREAM_ENABLED:
|
||||
logger.info(
|
||||
f"Using fake stream logic for model: {model}, Attempt: {retries + 1}"
|
||||
)
|
||||
stream_generator = self._fake_stream_logic_impl(
|
||||
model, payload, current_attempt_key
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Using real stream logic for model: {model}, Attempt: {retries + 1}"
|
||||
)
|
||||
stream_generator = self._real_stream_logic_impl(
|
||||
model, payload, current_attempt_key
|
||||
)
|
||||
|
||||
async for chunk_data in stream_generator:
|
||||
yield chunk_data
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info(
|
||||
f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
|
||||
)
|
||||
|
||||
match = re.search(r"status code (\\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
if isinstance(e, asyncio.TimeoutError):
|
||||
status_code = 408
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
)
|
||||
|
||||
if self.key_manager:
|
||||
new_api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if new_api_key and new_api_key != current_attempt_key:
|
||||
final_api_key = new_api_key
|
||||
logger.info(
|
||||
f"Switched to new API key for next attempt: {final_api_key}"
|
||||
)
|
||||
elif not new_api_key:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries, ceasing attempts for this request."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error(
|
||||
"KeyManager not available, cannot switch API key. Ceasing attempts for this request."
|
||||
)
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming model {model}."
|
||||
)
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=current_attempt_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
if not is_success:
|
||||
logger.error(
|
||||
f"Streaming failed permanently for model {model} after {retries} attempts."
|
||||
)
|
||||
yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self, request: ChatRequest, api_key: str
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
|
||||
image_generate_request = ImageGenerationRequest()
|
||||
image_generate_request.prompt = request.messages[-1]["content"]
|
||||
image_res = self.image_create_service.generate_images_chat(
|
||||
image_generate_request
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_image_completion(
|
||||
request.model, image_res, api_key
|
||||
)
|
||||
else:
|
||||
return await self._handle_normal_image_completion(
|
||||
request.model, image_res, api_key
|
||||
)
|
||||
|
||||
async def _handle_stream_image_completion(
|
||||
self, model: str, image_data: str, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
logger.info(f"Starting stream image completion for model: {model}")
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
|
||||
try:
|
||||
if image_data:
|
||||
openai_chunk = self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=True, finish_reason=None
|
||||
)
|
||||
if openai_chunk:
|
||||
# 提取文本内容
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for (
|
||||
optimized_chunk
|
||||
) in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
# 如果没有文本内容(如图片URL等),整块输出
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
logger.info(
|
||||
f"Stream image completion finished successfully for model: {model}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Stream image completion failed for model {model}: {e}"
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
)
|
||||
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(
|
||||
f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _handle_normal_image_completion(
|
||||
self, model: str, image_data: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"Starting normal image completion for model: {model}")
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
result = self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
logger.info(
|
||||
f"Normal image completion finished successfully for model: {model}"
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return result
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Normal image completion failed for model {model}: {e}"
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
)
|
||||
# Re-raise the exception so the caller knows about the failure
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(
|
||||
f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
195
app/service/client/api_client.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# app/services/chat/api_client.py
|
||||
|
||||
from typing import Dict, Any, AsyncGenerator, Optional
|
||||
import httpx
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_api_client_logger
|
||||
from app.core.constants import DEFAULT_TIMEOUT
|
||||
|
||||
logger = get_api_client_logger()
|
||||
|
||||
class ApiClient(ABC):
|
||||
"""API客户端基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
pass
|
||||
|
||||
|
||||
class GeminiApiClient(ApiClient):
|
||||
"""Gemini API客户端"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
def _get_real_model(self, model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
if model.endswith("-non-thinking"):
|
||||
model = model[:-13]
|
||||
if "-search" in model and "-non-thinking" in model:
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取可用的 Gemini 模型列表"""
|
||||
timeout = httpx.Timeout(timeout=5)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models?key={api_key}"
|
||||
try:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status() # 如果状态码不是 2xx,则引发 HTTPStatusError
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"获取模型列表失败: {e.response.status_code}")
|
||||
logger.error(e.response.text)
|
||||
# 返回 None 而不是抛出异常,以便上层处理
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求模型列表失败: {e}")
|
||||
# 返回 None 而不是抛出异常
|
||||
return None
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
"""OpenAI API客户端"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/openai/models"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/embeddings"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
payload = {
|
||||
"input": input,
|
||||
"model": model,
|
||||
}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy: {proxy_to_use}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/images/generations"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
266
app/service/config/config_service.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
配置服务模块
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import insert, update
|
||||
|
||||
from app.config.config import Settings as ConfigSettings
|
||||
from app.config.config import settings
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings
|
||||
from app.database.services import get_all_settings
|
||||
from app.log.logger import get_config_routes_logger
|
||||
from app.service.key.key_manager import (
|
||||
get_key_manager_instance,
|
||||
reset_key_manager_instance,
|
||||
)
|
||||
from app.service.model.model_service import ModelService
|
||||
|
||||
logger = get_config_routes_logger()
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""配置服务类,用于管理应用程序配置"""
|
||||
|
||||
@staticmethod
|
||||
async def get_config() -> Dict[str, Any]:
|
||||
return settings.model_dump()
|
||||
|
||||
@staticmethod
|
||||
async def update_config(config_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
for key, value in config_data.items():
|
||||
if hasattr(settings, key):
|
||||
setattr(settings, key, value)
|
||||
logger.debug(f"Updated setting in memory: {key}")
|
||||
|
||||
# 获取现有设置
|
||||
existing_settings_raw: List[Dict[str, Any]] = await get_all_settings()
|
||||
existing_settings_map: Dict[str, Dict[str, Any]] = {
|
||||
s["key"]: s for s in existing_settings_raw
|
||||
}
|
||||
existing_keys = set(existing_settings_map.keys())
|
||||
|
||||
settings_to_update: List[Dict[str, Any]] = []
|
||||
settings_to_insert: List[Dict[str, Any]] = []
|
||||
now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))
|
||||
|
||||
# 准备要更新或插入的数据
|
||||
for key, value in config_data.items():
|
||||
# 处理不同类型的值
|
||||
if isinstance(value, list):
|
||||
db_value = json.dumps(value)
|
||||
elif isinstance(value, dict): # 新增对 dict 类型的处理
|
||||
db_value = json.dumps(value)
|
||||
elif isinstance(value, bool):
|
||||
db_value = str(value).lower()
|
||||
else:
|
||||
db_value = str(value)
|
||||
|
||||
# 仅当值发生变化时才更新
|
||||
if key in existing_keys and existing_settings_map[key]["value"] == db_value:
|
||||
continue
|
||||
|
||||
description = f"{key}配置项"
|
||||
|
||||
data = {
|
||||
"key": key,
|
||||
"value": db_value,
|
||||
"description": description,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
if key in existing_keys:
|
||||
# Preserve original description if not explicitly provided
|
||||
data["description"] = existing_settings_map[key].get(
|
||||
"description", description
|
||||
)
|
||||
settings_to_update.append(data)
|
||||
else:
|
||||
data["created_at"] = now
|
||||
settings_to_insert.append(data)
|
||||
|
||||
# 在事务中执行批量插入和更新
|
||||
if settings_to_insert or settings_to_update:
|
||||
try:
|
||||
async with database.transaction():
|
||||
if settings_to_insert:
|
||||
query_insert = insert(Settings).values(settings_to_insert)
|
||||
await database.execute(query=query_insert)
|
||||
logger.info(
|
||||
f"Bulk inserted {len(settings_to_insert)} settings."
|
||||
)
|
||||
|
||||
if settings_to_update:
|
||||
for setting_data in settings_to_update:
|
||||
query_update = (
|
||||
update(Settings)
|
||||
.where(Settings.key == setting_data["key"])
|
||||
.values(
|
||||
value=setting_data["value"],
|
||||
description=setting_data["description"],
|
||||
updated_at=setting_data["updated_at"],
|
||||
)
|
||||
)
|
||||
await database.execute(query=query_update)
|
||||
logger.info(f"Updated {len(settings_to_update)} settings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to bulk update/insert settings: {str(e)}")
|
||||
raise # Re-raise the exception after logging
|
||||
|
||||
# 重置并重新初始化 KeyManager
|
||||
try:
|
||||
await reset_key_manager_instance()
|
||||
await get_key_manager_instance(settings.API_KEYS)
|
||||
logger.info("KeyManager instance re-initialized with updated settings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to re-initialize KeyManager: {str(e)}")
|
||||
# Decide if this error should prevent returning the updated config
|
||||
# For now, we log the error and continue
|
||||
|
||||
return await ConfigService.get_config()
|
||||
|
||||
@staticmethod
|
||||
async def delete_key(key_to_delete: str) -> Dict[str, Any]:
|
||||
"""删除单个API密钥"""
|
||||
# 确保 settings.API_KEYS 是一个列表
|
||||
if not isinstance(settings.API_KEYS, list):
|
||||
settings.API_KEYS = []
|
||||
|
||||
original_keys_count = len(settings.API_KEYS)
|
||||
# 创建一个不包含待删除密钥的新列表
|
||||
updated_api_keys = [k for k in settings.API_KEYS if k != key_to_delete]
|
||||
|
||||
if len(updated_api_keys) < original_keys_count:
|
||||
# 密钥已找到并从列表中移除
|
||||
settings.API_KEYS = updated_api_keys # 首先更新内存中的 settings
|
||||
# 使用 update_config 持久化更改,它同时处理数据库和 KeyManager
|
||||
await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
|
||||
logger.info(f"密钥 '{key_to_delete}' 已成功删除。")
|
||||
return {"success": True, "message": f"密钥 '{key_to_delete}' 已成功删除。"}
|
||||
else:
|
||||
# 未找到密钥
|
||||
logger.warning(f"尝试删除密钥 '{key_to_delete}',但未找到该密钥。")
|
||||
return {"success": False, "message": f"未找到密钥 '{key_to_delete}'。"}
|
||||
|
||||
@staticmethod
|
||||
async def delete_selected_keys(keys_to_delete: List[str]) -> Dict[str, Any]:
|
||||
"""批量删除选定的API密钥"""
|
||||
if not isinstance(settings.API_KEYS, list):
|
||||
settings.API_KEYS = []
|
||||
|
||||
deleted_count = 0
|
||||
not_found_keys: List[str] = []
|
||||
|
||||
current_api_keys = list(settings.API_KEYS) # 创建副本以进行修改
|
||||
keys_actually_removed: List[str] = []
|
||||
|
||||
for key_to_del in keys_to_delete:
|
||||
if key_to_del in current_api_keys:
|
||||
current_api_keys.remove(key_to_del)
|
||||
keys_actually_removed.append(key_to_del)
|
||||
deleted_count += 1
|
||||
else:
|
||||
not_found_keys.append(key_to_del)
|
||||
|
||||
if deleted_count > 0:
|
||||
settings.API_KEYS = current_api_keys # 更新内存中的 settings
|
||||
await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
|
||||
logger.info(
|
||||
f"成功删除 {deleted_count} 个密钥。密钥: {keys_actually_removed}"
|
||||
)
|
||||
message = f"成功删除 {deleted_count} 个密钥。"
|
||||
if not_found_keys:
|
||||
message += f" {len(not_found_keys)} 个密钥未找到: {not_found_keys}。"
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"deleted_count": deleted_count,
|
||||
"not_found_keys": not_found_keys,
|
||||
}
|
||||
else:
|
||||
message = "没有密钥被删除。"
|
||||
if not_found_keys: # 如果提供了密钥但都未找到
|
||||
message = f"所有 {len(not_found_keys)} 个指定的密钥均未找到: {not_found_keys}。"
|
||||
elif not keys_to_delete: # 如果 keys_to_delete 列表为空
|
||||
message = "未指定要删除的密钥。"
|
||||
logger.warning(message)
|
||||
return {
|
||||
"success": False,
|
||||
"message": message,
|
||||
"deleted_count": 0,
|
||||
"not_found_keys": not_found_keys,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def reset_config() -> Dict[str, Any]:
|
||||
"""
|
||||
重置配置:优先从系统环境变量加载,然后从 .env 文件加载,
|
||||
更新内存中的 settings 对象,并刷新 KeyManager。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 重置后的配置字典
|
||||
"""
|
||||
# 1. 重新加载配置对象,它应该处理环境变量和 .env 的优先级
|
||||
_reload_settings()
|
||||
logger.info(
|
||||
"Settings object reloaded, prioritizing system environment variables then .env file."
|
||||
)
|
||||
|
||||
# 2. 重置并重新初始化 KeyManager
|
||||
try:
|
||||
await reset_key_manager_instance()
|
||||
# 确保使用更新后的 settings 中的 API_KEYS
|
||||
await get_key_manager_instance(settings.API_KEYS)
|
||||
logger.info("KeyManager instance re-initialized with reloaded settings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to re-initialize KeyManager during reset: {str(e)}")
|
||||
# 根据需要决定是否抛出异常或继续
|
||||
# 这里选择记录错误并继续
|
||||
|
||||
# 3. 返回更新后的配置
|
||||
return await ConfigService.get_config()
|
||||
|
||||
@staticmethod
|
||||
async def fetch_ui_models() -> List[Dict[str, Any]]:
|
||||
"""获取用于UI显示的模型列表"""
|
||||
try:
|
||||
key_manager = await get_key_manager_instance()
|
||||
model_service = ModelService()
|
||||
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
if not api_key:
|
||||
logger.error("No valid API keys available to fetch model list for UI.")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="No valid API keys available to fetch model list.",
|
||||
)
|
||||
|
||||
models = await model_service.get_gemini_openai_models(api_key)
|
||||
return models
|
||||
except HTTPException as e:
|
||||
# Re-raise HTTPExceptions directly if they are already specific
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch models for UI in ConfigService: {e}", exc_info=True
|
||||
)
|
||||
# Raise a generic HTTPException for other errors
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch models for UI: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# 重新加载配置的函数
|
||||
def _reload_settings():
|
||||
"""重新加载环境变量并更新配置"""
|
||||
# 显式加载 .env 文件,覆盖现有环境变量
|
||||
load_dotenv(find_dotenv(), override=True)
|
||||
# 更新现有 settings 对象的属性,而不是新建实例
|
||||
for key, value in ConfigSettings().model_dump().items():
|
||||
setattr(settings, key, value)
|
||||
82
app/service/embedding/embedding_service.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import datetime
|
||||
import time
|
||||
import re # For potential status code parsing from generic errors
|
||||
from typing import List, Union
|
||||
|
||||
import openai
|
||||
from openai import APIStatusError # Import specific error type
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_embeddings_logger
|
||||
from app.database.services import add_error_log, add_request_log # Import DB logging functions
|
||||
|
||||
logger = get_embeddings_logger()
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
|
||||
async def create_embedding(
|
||||
self, input_text: Union[str, List[str]], model: str, api_key: str
|
||||
) -> CreateEmbeddingResponse:
|
||||
"""Create embeddings using OpenAI API with database logging"""
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
error_log_msg = ""
|
||||
# Prepare request message for logging (truncate if list or long string)
|
||||
if isinstance(input_text, list):
|
||||
request_msg_log = {"input_truncated": [str(item)[:100] + "..." if len(str(item)) > 100 else str(item) for item in input_text[:5]]}
|
||||
if len(input_text) > 5:
|
||||
request_msg_log["input_truncated"].append("...")
|
||||
else:
|
||||
request_msg_log = {"input_truncated": input_text[:1000] + "..." if len(input_text) > 1000 else input_text}
|
||||
|
||||
|
||||
try:
|
||||
client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL)
|
||||
response = client.embeddings.create(input=input_text, model=model)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except APIStatusError as e:
|
||||
is_success = False
|
||||
status_code = e.status_code
|
||||
error_log_msg = f"OpenAI API error: {e}"
|
||||
logger.error(f"Error creating embedding (APIStatusError): {error_log_msg}")
|
||||
raise e # Re-raise the specific error
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Generic error: {e}"
|
||||
logger.error(f"Error creating embedding (Exception): {error_log_msg}")
|
||||
# Try to parse status code from generic error (less reliable)
|
||||
match = re.search(r"status code (\d+)", str(e))
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500 # Default if parsing fails
|
||||
raise e # Re-raise the generic error
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
if not is_success:
|
||||
# Log error to database if it failed
|
||||
await add_error_log(
|
||||
gemini_key=api_key, # Using gemini_key parameter name for consistency
|
||||
model_name=model,
|
||||
error_type="openai-embedding",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request_msg_log
|
||||
)
|
||||
# Log request outcome to database regardless of success/failure
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
155
app/service/error_log/error_log_service.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.connection import database
|
||||
from app.database.models import ErrorLog
|
||||
from app.log.logger import get_error_log_logger
|
||||
|
||||
logger = get_error_log_logger()
|
||||
|
||||
|
||||
async def delete_old_error_logs():
|
||||
"""
|
||||
Deletes error logs older than a specified number of days,
|
||||
based on the AUTO_DELETE_ERROR_LOGS_ENABLED and AUTO_DELETE_ERROR_LOGS_DAYS settings.
|
||||
"""
|
||||
if not settings.AUTO_DELETE_ERROR_LOGS_ENABLED:
|
||||
logger.info("Auto-deletion of error logs is disabled. Skipping.")
|
||||
return
|
||||
|
||||
days_to_keep = settings.AUTO_DELETE_ERROR_LOGS_DAYS
|
||||
if not isinstance(days_to_keep, int) or days_to_keep <= 0:
|
||||
logger.error(
|
||||
f"Invalid AUTO_DELETE_ERROR_LOGS_DAYS value: {days_to_keep}. Must be a positive integer. Skipping deletion."
|
||||
)
|
||||
return
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
logger.info(
|
||||
f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."
|
||||
)
|
||||
|
||||
try:
|
||||
if not database.is_connected:
|
||||
await database.connect()
|
||||
logger.info("Database connection established for deleting error logs.")
|
||||
|
||||
# First, count how many logs will be deleted (optional, for logging)
|
||||
count_query = select(func.count(ErrorLog.id)).where(
|
||||
ErrorLog.request_time < cutoff_date
|
||||
)
|
||||
num_logs_to_delete = await database.fetch_val(count_query)
|
||||
|
||||
if num_logs_to_delete == 0:
|
||||
logger.info(
|
||||
"No error logs found older than the specified period. No deletion needed."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Found {num_logs_to_delete} error logs to delete.")
|
||||
|
||||
# Perform the deletion
|
||||
query = delete(ErrorLog).where(ErrorLog.request_time < cutoff_date)
|
||||
await database.execute(query)
|
||||
logger.info(
|
||||
f"Successfully deleted {num_logs_to_delete} error logs older than {days_to_keep} days."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during automatic deletion of error logs: {e}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
async def process_get_error_logs(
|
||||
limit: int,
|
||||
offset: int,
|
||||
key_search: Optional[str],
|
||||
error_search: Optional[str],
|
||||
error_code_search: Optional[str],
|
||||
start_date: Optional[datetime],
|
||||
end_date: Optional[datetime],
|
||||
sort_by: str,
|
||||
sort_order: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理错误日志的检索,支持分页和过滤。
|
||||
"""
|
||||
try:
|
||||
logs_data = await db_services.get_error_logs(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
total_count = await db_services.get_error_logs_count(
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return {"logs": logs_data, "total": total_count}
|
||||
except Exception as e:
|
||||
logger.error(f"Service error in process_get_error_logs: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def process_get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
处理特定错误日志详细信息的检索。
|
||||
如果未找到,则返回 None。
|
||||
"""
|
||||
try:
|
||||
log_details = await db_services.get_error_log_details(log_id=log_id)
|
||||
return log_details
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_get_error_log_details for ID {log_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def process_delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
||||
"""
|
||||
按 ID 批量删除错误日志。
|
||||
返回尝试删除的日志数量。
|
||||
"""
|
||||
if not log_ids:
|
||||
return 0
|
||||
try:
|
||||
deleted_count = await db_services.delete_error_logs_by_ids(log_ids)
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_delete_error_logs_by_ids for IDs {log_ids}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def process_delete_error_log_by_id(log_id: int) -> bool:
|
||||
"""
|
||||
按 ID 删除单个错误日志。
|
||||
如果删除成功(或找到日志并尝试删除),则返回 True,否则返回 False。
|
||||
"""
|
||||
try:
|
||||
success = await db_services.delete_error_log_by_id(log_id)
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_delete_error_log_by_id for ID {log_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
@@ -1,14 +1,15 @@
|
||||
import base64
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
import base64
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_image_create_logger
|
||||
from app.core.uploader import ImageUploaderFactory
|
||||
from app.schemas.openai_models import ImageGenerationRequest
|
||||
from app.config.config import settings
|
||||
from app.core.constants import VALID_IMAGE_RATIOS
|
||||
from app.domain.openai_models import ImageGenerationRequest
|
||||
from app.log.logger import get_image_create_logger
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
logger = get_image_create_logger()
|
||||
|
||||
@@ -16,7 +17,6 @@ logger = get_image_create_logger()
|
||||
class ImageCreateService:
|
||||
def __init__(self, aspect_ratio="1:1"):
|
||||
self.image_model = settings.CREATE_IMAGE_MODEL
|
||||
self.paid_key = settings.PAID_KEY
|
||||
self.aspect_ratio = aspect_ratio
|
||||
|
||||
def parse_prompt_parameters(self, prompt: str) -> tuple:
|
||||
@@ -26,35 +26,34 @@ class ImageCreateService:
|
||||
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
# 默认值
|
||||
n = 1
|
||||
aspect_ratio = self.aspect_ratio
|
||||
|
||||
|
||||
# 解析n参数
|
||||
n_match = re.search(r'{n:(\d+)}', prompt)
|
||||
n_match = re.search(r"{n:(\d+)}", prompt)
|
||||
if n_match:
|
||||
n = int(n_match.group(1))
|
||||
if n < 1 or n > 4:
|
||||
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||
prompt = prompt.replace(n_match.group(0), '').strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||
prompt = prompt.replace(n_match.group(0), "").strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
|
||||
if ratio_match:
|
||||
aspect_ratio = ratio_match.group(1)
|
||||
valid_ratios = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
||||
if aspect_ratio not in valid_ratios:
|
||||
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||
raise ValueError(
|
||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(valid_ratios)}"
|
||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||
)
|
||||
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||
|
||||
prompt = prompt.replace(ratio_match.group(0), "").strip()
|
||||
|
||||
return prompt, n, aspect_ratio
|
||||
|
||||
def generate_images(self, request: ImageGenerationRequest):
|
||||
client = genai.Client(api_key=self.paid_key)
|
||||
|
||||
client = genai.Client(api_key=settings.PAID_KEY)
|
||||
|
||||
if request.size == "1024x1024":
|
||||
self.aspect_ratio = "1:1"
|
||||
elif request.size == "1792x1024":
|
||||
@@ -67,13 +66,15 @@ class ImageCreateService:
|
||||
)
|
||||
|
||||
# 解析prompt中的参数
|
||||
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(request.prompt)
|
||||
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
|
||||
request.prompt
|
||||
)
|
||||
request.prompt = cleaned_prompt
|
||||
|
||||
|
||||
# 如果prompt中指定了n,则覆盖请求中的n
|
||||
if prompt_n > 1:
|
||||
request.n = prompt_n
|
||||
|
||||
|
||||
# 如果prompt中指定了ratio,则覆盖默认的aspect_ratio
|
||||
if prompt_ratio != self.aspect_ratio:
|
||||
self.aspect_ratio = prompt_ratio
|
||||
@@ -87,7 +88,6 @@ class ImageCreateService:
|
||||
aspect_ratio=self.aspect_ratio,
|
||||
safety_filter_level="BLOCK_LOW_AND_ABOVE",
|
||||
person_generation="ALLOW_ADULT",
|
||||
# language="auto"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -96,46 +96,49 @@ class ImageCreateService:
|
||||
for index, generated_image in enumerate(response.generated_images):
|
||||
image_data = generated_image.image.image_bytes
|
||||
image_uploader = None
|
||||
|
||||
|
||||
if request.response_format == "b64_json":
|
||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||
images_data.append({
|
||||
"b64_json": base64_image,
|
||||
"revised_prompt": request.prompt
|
||||
})
|
||||
base64_image = base64.b64encode(image_data).decode("utf-8")
|
||||
images_data.append(
|
||||
{"b64_json": base64_image, "revised_prompt": request.prompt}
|
||||
)
|
||||
else:
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
|
||||
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.SMMS_SECRET_TOKEN
|
||||
api_key=settings.SMMS_SECRET_TOKEN,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.PICGO_API_KEY
|
||||
api_key=settings.PICGO_API_KEY,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
|
||||
)
|
||||
|
||||
upload_response = image_uploader.upload(image_data, filename)
|
||||
|
||||
images_data.append({
|
||||
"url": f"{upload_response.data.url}",
|
||||
"revised_prompt": request.prompt
|
||||
})
|
||||
images_data.append(
|
||||
{
|
||||
"url": f"{upload_response.data.url}",
|
||||
"revised_prompt": request.prompt,
|
||||
}
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"created": int(time.time()), # Current timestamp
|
||||
"data": images_data
|
||||
"data": images_data,
|
||||
}
|
||||
return response_data
|
||||
else:
|
||||
@@ -147,9 +150,13 @@ class ImageCreateService:
|
||||
if image_datas:
|
||||
markdown_images = []
|
||||
for index, image_data in enumerate(image_datas):
|
||||
if 'url' in image_data:
|
||||
markdown_images.append(f"")
|
||||
if "url" in image_data:
|
||||
markdown_images.append(
|
||||
f""
|
||||
)
|
||||
else:
|
||||
# 如果是base64格式,创建data URL
|
||||
markdown_images.append(f"")
|
||||
markdown_images.append(
|
||||
f""
|
||||
)
|
||||
return "\n".join(markdown_images)
|
||||
308
app/service/key/key_manager.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import asyncio
|
||||
from itertools import cycle
|
||||
from typing import Dict
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_key_manager_logger
|
||||
|
||||
logger = get_key_manager_logger()
|
||||
|
||||
|
||||
class KeyManager:
|
||||
def __init__(self, api_keys: list):
|
||||
self.api_keys = api_keys
|
||||
self.key_cycle = cycle(api_keys)
|
||||
self.key_cycle_lock = asyncio.Lock()
|
||||
self.failure_count_lock = asyncio.Lock()
|
||||
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
|
||||
self.MAX_FAILURES = settings.MAX_FAILURES
|
||||
self.paid_key = settings.PAID_KEY
|
||||
|
||||
async def get_paid_key(self) -> str:
|
||||
return self.paid_key
|
||||
|
||||
async def get_next_key(self) -> str:
|
||||
"""获取下一个API key"""
|
||||
async with self.key_cycle_lock:
|
||||
return next(self.key_cycle)
|
||||
|
||||
async def is_key_valid(self, key: str) -> bool:
|
||||
"""检查key是否有效"""
|
||||
async with self.failure_count_lock:
|
||||
return self.key_failure_counts[key] < self.MAX_FAILURES
|
||||
|
||||
async def reset_failure_counts(self):
|
||||
"""重置所有key的失败计数"""
|
||||
async with self.failure_count_lock:
|
||||
for key in self.key_failure_counts:
|
||||
self.key_failure_counts[key] = 0
|
||||
|
||||
async def reset_key_failure_count(self, key: str) -> bool:
|
||||
"""重置指定key的失败计数"""
|
||||
async with self.failure_count_lock:
|
||||
if key in self.key_failure_counts:
|
||||
self.key_failure_counts[key] = 0
|
||||
logger.info(f"Reset failure count for key: {key}")
|
||||
return True
|
||||
logger.warning(
|
||||
f"Attempt to reset failure count for non-existent key: {key}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def get_next_working_key(self) -> str:
|
||||
"""获取下一可用的API key"""
|
||||
initial_key = await self.get_next_key()
|
||||
current_key = initial_key
|
||||
|
||||
while True:
|
||||
if await self.is_key_valid(current_key):
|
||||
return current_key
|
||||
|
||||
current_key = await self.get_next_key()
|
||||
if current_key == initial_key:
|
||||
# await self.reset_failure_counts() 取消重置
|
||||
return current_key
|
||||
|
||||
async def handle_api_failure(self, api_key: str, retries: int) -> str:
|
||||
"""处理API调用失败"""
|
||||
async with self.failure_count_lock:
|
||||
self.key_failure_counts[api_key] += 1
|
||||
if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
)
|
||||
if retries < settings.MAX_RETRIES:
|
||||
return await self.get_next_working_key()
|
||||
else:
|
||||
return ""
|
||||
|
||||
def get_fail_count(self, key: str) -> int:
|
||||
"""获取指定密钥的失败次数"""
|
||||
return self.key_failure_counts.get(key, 0)
|
||||
|
||||
async def get_keys_by_status(self) -> dict:
|
||||
"""获取分类后的API key列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
invalid_keys = {}
|
||||
|
||||
async with self.failure_count_lock:
|
||||
for key in self.api_keys:
|
||||
fail_count = self.key_failure_counts[key]
|
||||
if fail_count < self.MAX_FAILURES:
|
||||
valid_keys[key] = fail_count
|
||||
else:
|
||||
invalid_keys[key] = fail_count
|
||||
|
||||
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
||||
|
||||
async def get_first_valid_key(self) -> str:
|
||||
"""获取第一个有效的API key"""
|
||||
async with self.failure_count_lock:
|
||||
for key in self.key_failure_counts:
|
||||
if self.key_failure_counts[key] < self.MAX_FAILURES:
|
||||
return key
|
||||
# 如果所有 key 都无效,或者列表为空,则尝试返回第一个(如果列表不为空)
|
||||
# 或者根据具体逻辑处理,这里保持原样,可能在空列表或全无效时需要调整
|
||||
if self.api_keys:
|
||||
return self.api_keys[0]
|
||||
# 如果 api_keys 为空,这里会出问题。实际应用中应有非空保证或更好处理。
|
||||
# 为了保持接口一致性,如果列表为空,可能应该抛出异常或返回特定值。
|
||||
# 暂且假设 api_keys 不会为空,或者调用者处理后续的空 key 问题。
|
||||
# 根据现有代码,如果api_keys为空,self.api_keys[0]会报错。
|
||||
# 如果没有有效key且列表不空,返回第一个。若列表为空,这里会出IndexError。
|
||||
# 更安全的做法是:
|
||||
if not self.api_keys:
|
||||
logger.warning("API key list is empty, cannot get first valid key.")
|
||||
# Depending on desired behavior, either raise error or return an indicator like "" or None
|
||||
# For now, let's allow it to potentially fail if a key is expected by caller
|
||||
# but it's better to be explicit. Let's return empty string for consistency with handle_api_failure
|
||||
return ""
|
||||
return self.api_keys[
|
||||
0
|
||||
] # Fallback to the first key if no key is "valid" but list is not empty
|
||||
|
||||
|
||||
_singleton_instance = None
|
||||
_singleton_lock = asyncio.Lock()
|
||||
_preserved_failure_counts: Dict[str, int] | None = None
|
||||
_preserved_old_api_keys_for_reset: list | None = None
|
||||
_preserved_next_key_in_cycle: str | None = None
|
||||
|
||||
|
||||
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
|
||||
"""
|
||||
获取 KeyManager 单例实例。
|
||||
|
||||
如果尚未创建实例,将使用提供的 api_keys 初始化 KeyManager。
|
||||
如果已创建实例,则忽略 api_keys 参数,返回现有单例。
|
||||
如果在重置后调用,会尝试恢复之前的状态(失败计数、循环位置)。
|
||||
"""
|
||||
global _singleton_instance, _preserved_failure_counts, _preserved_old_api_keys_for_reset, _preserved_next_key_in_cycle
|
||||
|
||||
async with _singleton_lock:
|
||||
if _singleton_instance is None:
|
||||
if api_keys is None:
|
||||
# This case needs careful handling. If it's the very first call, api_keys are required.
|
||||
# If it's after a reset and no api_keys are provided, what should happen?
|
||||
# The original ValueError was "API keys are required to initialize the KeyManager".
|
||||
# Let's assume if api_keys is None here, it's an error unless we are restoring from non-None _preserved_old_api_keys_for_reset.
|
||||
# However, the user's request implies new api_keys will be part of the reset flow.
|
||||
# For now, stick to a strict requirement for api_keys if _singleton_instance is None.
|
||||
raise ValueError(
|
||||
"API keys are required to initialize or re-initialize the KeyManager instance."
|
||||
)
|
||||
if not api_keys: # Handle case where api_keys is an empty list
|
||||
logger.warning(
|
||||
"Initializing KeyManager with an empty list of API keys."
|
||||
)
|
||||
# Consider if this should be an error or allowed. Current KeyManager supports it.
|
||||
|
||||
_singleton_instance = KeyManager(api_keys)
|
||||
logger.info(
|
||||
f"KeyManager instance created/re-created with {len(api_keys)} API keys."
|
||||
)
|
||||
|
||||
# 1. 恢复失败计数
|
||||
if _preserved_failure_counts:
|
||||
# Initialize new instance's failure_counts for all new keys to 0
|
||||
current_failure_counts = {
|
||||
key: 0 for key in _singleton_instance.api_keys
|
||||
}
|
||||
# Inherit counts for keys that exist in both old and new lists
|
||||
for key, count in _preserved_failure_counts.items():
|
||||
if key in current_failure_counts:
|
||||
current_failure_counts[key] = count
|
||||
_singleton_instance.key_failure_counts = current_failure_counts
|
||||
logger.info("Inherited failure counts for applicable keys.")
|
||||
_preserved_failure_counts = None # Clear after use
|
||||
|
||||
# 2. 调整 key_cycle 的起始点
|
||||
start_key_for_new_cycle = None
|
||||
if (
|
||||
_preserved_old_api_keys_for_reset
|
||||
and _preserved_next_key_in_cycle
|
||||
and _singleton_instance.api_keys # Ensure new api_keys list is not empty
|
||||
):
|
||||
try:
|
||||
# Find the index of the preserved next key in the *old* list
|
||||
start_idx_in_old = _preserved_old_api_keys_for_reset.index(
|
||||
_preserved_next_key_in_cycle
|
||||
)
|
||||
|
||||
# Iterate through the old key list (circularly) starting from _preserved_next_key_in_cycle
|
||||
# Find the first key that also exists in the new api_keys list
|
||||
for i in range(len(_preserved_old_api_keys_for_reset)):
|
||||
current_old_key_idx = (start_idx_in_old + i) % len(
|
||||
_preserved_old_api_keys_for_reset
|
||||
)
|
||||
key_candidate = _preserved_old_api_keys_for_reset[
|
||||
current_old_key_idx
|
||||
]
|
||||
if key_candidate in _singleton_instance.api_keys:
|
||||
start_key_for_new_cycle = key_candidate
|
||||
break
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Preserved next key '{_preserved_next_key_in_cycle}' not found in preserved old API keys. "
|
||||
"New cycle will start from the beginning of the new list."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error determining start key for new cycle from preserved state: {e}. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
|
||||
if start_key_for_new_cycle and _singleton_instance.api_keys:
|
||||
try:
|
||||
# Find the index of the determined start_key in the new api_keys list
|
||||
target_idx = _singleton_instance.api_keys.index(
|
||||
start_key_for_new_cycle
|
||||
)
|
||||
# Advance the new cycle by calling next() target_idx times
|
||||
# This positions the cycle so that the *next* call to next() will yield start_key_for_new_cycle
|
||||
for _ in range(target_idx):
|
||||
next(_singleton_instance.key_cycle)
|
||||
logger.info(
|
||||
f"Key cycle in new instance advanced. Next call to get_next_key() will yield: {start_key_for_new_cycle}"
|
||||
)
|
||||
except ValueError:
|
||||
# This should not happen if start_key_for_new_cycle was correctly found in api_keys
|
||||
logger.warning(
|
||||
f"Determined start key '{start_key_for_new_cycle}' not found in new API keys during cycle advancement. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
except (
|
||||
StopIteration
|
||||
): # Should not happen with cycle unless api_keys is empty, handled by _singleton_instance.api_keys check
|
||||
logger.error(
|
||||
"StopIteration while advancing key cycle, implies empty new API key list previously missed."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error advancing new key cycle: {e}. Cycle will start from beginning."
|
||||
)
|
||||
else:
|
||||
if _singleton_instance.api_keys:
|
||||
logger.info(
|
||||
"New key cycle will start from the beginning of the new API key list (no specific start key determined or needed)."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"New key cycle not applicable as the new API key list is empty."
|
||||
)
|
||||
|
||||
# 清理所有保存的状态
|
||||
_preserved_old_api_keys_for_reset = None
|
||||
_preserved_next_key_in_cycle = None
|
||||
# _preserved_failure_counts already cleared
|
||||
|
||||
return _singleton_instance
|
||||
|
||||
|
||||
async def reset_key_manager_instance():
|
||||
"""
|
||||
重置 KeyManager 单例实例。
|
||||
将保存当前实例的状态(失败计数、旧 API keys、下一个 key 提示)
|
||||
以供下一次 get_key_manager_instance 调用时恢复。
|
||||
"""
|
||||
global _singleton_instance, _preserved_failure_counts, _preserved_old_api_keys_for_reset, _preserved_next_key_in_cycle
|
||||
async with _singleton_lock:
|
||||
if _singleton_instance:
|
||||
# 1. 保存失败计数
|
||||
_preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
|
||||
|
||||
# 2. 保存旧的 API keys 列表
|
||||
_preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
|
||||
|
||||
# 3. 保存 key_cycle 的下一个 key 提示
|
||||
# This should be the key that get_next_key() would return next.
|
||||
try:
|
||||
if (
|
||||
_singleton_instance.api_keys
|
||||
): # Only if there are keys to cycle through
|
||||
# Calling get_next_key() consumes one key and returns it. This is the key
|
||||
# we want the new cycle to effectively start with.
|
||||
_preserved_next_key_in_cycle = (
|
||||
await _singleton_instance.get_next_key()
|
||||
)
|
||||
else:
|
||||
_preserved_next_key_in_cycle = None # No keys, so no next key
|
||||
except (
|
||||
StopIteration
|
||||
): # Should be caught by "if _singleton_instance.api_keys"
|
||||
logger.warning(
|
||||
"Could not preserve next key hint: key cycle was empty or exhausted in old instance."
|
||||
)
|
||||
_preserved_next_key_in_cycle = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error preserving next key hint during reset: {e}")
|
||||
_preserved_next_key_in_cycle = None
|
||||
|
||||
_singleton_instance = None
|
||||
logger.info(
|
||||
"KeyManager instance has been reset. State (failure counts, old keys, next key hint) preserved for next instantiation."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"KeyManager instance was not set (or already reset), no reset action performed."
|
||||
)
|
||||
93
app/service/model/model_service.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_model_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
|
||||
logger = get_model_logger()
|
||||
|
||||
|
||||
class ModelService:
|
||||
async def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""使用 GeminiApiClient 获取并过滤模型列表"""
|
||||
api_client = GeminiApiClient(base_url=settings.BASE_URL) # 实例化客户端
|
||||
gemini_models = await api_client.get_models(api_key)
|
||||
|
||||
if gemini_models is None:
|
||||
logger.error("从 API 客户端获取模型列表失败。")
|
||||
return None
|
||||
|
||||
try:
|
||||
filtered_models_list = []
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
if model_id not in settings.FILTERED_MODELS:
|
||||
filtered_models_list.append(model)
|
||||
else:
|
||||
logger.debug(f"Filtered out model: {model_id}")
|
||||
|
||||
gemini_models["models"] = filtered_models_list
|
||||
return gemini_models
|
||||
except Exception as e:
|
||||
logger.error(f"处理模型列表时出错: {e}")
|
||||
return None
|
||||
|
||||
async def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取 Gemini 模型并转换为 OpenAI 格式"""
|
||||
gemini_models = await self.get_gemini_models(api_key)
|
||||
if gemini_models is None:
|
||||
return None
|
||||
|
||||
return await self.convert_to_openai_models_format(gemini_models)
|
||||
|
||||
async def convert_to_openai_models_format(
|
||||
self, gemini_models: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
openai_format = {"object": "list", "data": [], "success": True}
|
||||
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
openai_model = {
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": int(datetime.now(timezone.utc).timestamp()),
|
||||
"owned_by": "google",
|
||||
"permission": [],
|
||||
"root": model["name"],
|
||||
"parent": None,
|
||||
}
|
||||
openai_format["data"].append(openai_model)
|
||||
|
||||
if model_id in settings.SEARCH_MODELS:
|
||||
search_model = openai_model.copy()
|
||||
search_model["id"] = f"{model_id}-search"
|
||||
openai_format["data"].append(search_model)
|
||||
if model_id in settings.IMAGE_MODELS:
|
||||
image_model = openai_model.copy()
|
||||
image_model["id"] = f"{model_id}-image"
|
||||
openai_format["data"].append(image_model)
|
||||
if model_id in settings.THINKING_MODELS:
|
||||
non_thinking_model = openai_model.copy()
|
||||
non_thinking_model["id"] = f"{model_id}-non-thinking"
|
||||
openai_format["data"].append(non_thinking_model)
|
||||
|
||||
if settings.CREATE_IMAGE_MODEL:
|
||||
image_model = openai_model.copy()
|
||||
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
openai_format["data"].append(image_model)
|
||||
return openai_format
|
||||
|
||||
async def check_model_support(self, model: str) -> bool:
|
||||
if not model or not isinstance(model, str):
|
||||
return False
|
||||
|
||||
model = model.strip()
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
return model in settings.SEARCH_MODELS
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
return model in settings.IMAGE_MODELS
|
||||
|
||||
return model not in settings.FILTERED_MODELS
|
||||
197
app/service/openai_compatiable/openai_compatiable_service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import (
|
||||
add_error_log,
|
||||
add_request_log,
|
||||
)
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.service.client.api_client import OpenaiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
class OpenAICompatiableService:
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
self.key_manager = key_manager
|
||||
self.base_url = base_url
|
||||
self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
return await self.api_client.get_models(api_key)
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
request_dict = request.model_dump()
|
||||
# 移除值为null的
|
||||
request_dict = {k: v for k, v in request_dict.items() if v is not None}
|
||||
del request_dict["top_k"] # 删除top_k参数,目前不支持该参数
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, request_dict, api_key)
|
||||
return await self._handle_normal_completion(request.model, request_dict, api_key)
|
||||
|
||||
async def generate_images(
|
||||
self,
|
||||
request: ImageGenerationRequest,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成图片"""
|
||||
request_dict = request.model_dump()
|
||||
# 移除值为null的
|
||||
request_dict = {k: v for k, v in request_dict.items() if v is not None}
|
||||
api_key = settings.PAID_KEY
|
||||
return await self.api_client.generate_images(request_dict, api_key)
|
||||
|
||||
async def create_embeddings(
|
||||
self,
|
||||
input_text: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建嵌入"""
|
||||
return await self.api_client.create_embeddings(input_text, model, api_key)
|
||||
|
||||
async def _handle_normal_completion(
|
||||
self, model: str, request: dict, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理普通聊天完成"""
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
try:
|
||||
response = await self.api_client.generate_content(request, api_key)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
# Try to parse status code from exception
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-compatiable-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: dict, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑"""
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
is_success = False
|
||||
status_code = None
|
||||
final_api_key = api_key
|
||||
|
||||
while retries < max_retries:
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, current_attempt_key
|
||||
):
|
||||
if line.startswith("data:"):
|
||||
# print(line)
|
||||
yield line + "\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
is_success = True
|
||||
status_code = 200
|
||||
break # 成功后退出循环
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
# Parse error code for logging
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# Log error to error log table
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="openai-compatiable-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
)
|
||||
|
||||
# Attempt to switch API Key
|
||||
# Ensure key_manager is available (might need adjustment if not always passed)
|
||||
if self.key_manager:
|
||||
api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
else:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
finally:
|
||||
# Log the final outcome of the streaming request
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=final_api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
# If the loop finished due to failure, yield error and DONE
|
||||
if not is_success and retries >= max_retries:
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
50
app/service/request_log/request_log_service.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Service for request log operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from app import database
|
||||
from app.config.config import settings
|
||||
from app.database.models import RequestLog
|
||||
from app.log.logger import Logger
|
||||
|
||||
logger = Logger.setup_logger("request_log_service")
|
||||
|
||||
|
||||
async def delete_old_request_logs_task():
|
||||
"""
|
||||
定时删除旧的请求日志。
|
||||
"""
|
||||
if not settings.AUTO_DELETE_REQUEST_LOGS_ENABLED:
|
||||
logger.info(
|
||||
"Auto-delete for request logs is disabled by settings. Skipping task."
|
||||
)
|
||||
return
|
||||
|
||||
days_to_keep = settings.AUTO_DELETE_REQUEST_LOGS_DAYS
|
||||
logger.info(
|
||||
f"Starting scheduled task to delete old request logs older than {days_to_keep} days."
|
||||
)
|
||||
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
query = delete(RequestLog).where(RequestLog.request_time < cutoff_date)
|
||||
|
||||
if not database.is_connected:
|
||||
logger.info("Connecting to database for request log deletion.")
|
||||
await database.connect()
|
||||
|
||||
result = await database.execute(query)
|
||||
logger.info(
|
||||
f"Request logs older than {cutoff_date} potentially deleted. Rows affected: {result.rowcount if result else 'N/A'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled request log deletion: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
255
app/service/stats/stats_service.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# app/service/stats_service.py
|
||||
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
|
||||
from app.database.connection import database
|
||||
from app.database.models import RequestLog
|
||||
from app.log.logger import get_stats_logger
|
||||
|
||||
logger = get_stats_logger()
|
||||
|
||||
|
||||
class StatsService:
|
||||
"""Service class for handling statistics related operations."""
|
||||
|
||||
async def get_calls_in_last_seconds(self, seconds: int) -> dict[str, int]:
|
||||
"""获取过去 N 秒内的调用次数 (总数、成功、失败)"""
|
||||
try:
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(seconds=seconds)
|
||||
query = select(
|
||||
func.count(RequestLog.id).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
and_(
|
||||
RequestLog.status_code >= 200,
|
||||
RequestLog.status_code < 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
).label("success"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
or_(
|
||||
RequestLog.status_code < 200,
|
||||
RequestLog.status_code >= 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
(RequestLog.status_code is None, 1), # type: ignore
|
||||
else_=0,
|
||||
)
|
||||
).label("failure"),
|
||||
).where(RequestLog.request_time >= cutoff_time)
|
||||
result = await database.fetch_one(query)
|
||||
if result:
|
||||
return {
|
||||
"total": result["total"] or 0,
|
||||
"success": result["success"] or 0,
|
||||
"failure": result["failure"] or 0,
|
||||
}
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get calls in last {seconds} seconds: {e}")
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
|
||||
async def get_calls_in_last_minutes(self, minutes: int) -> dict[str, int]:
|
||||
"""获取过去 N 分钟内的调用次数 (总数、成功、失败)"""
|
||||
return await self.get_calls_in_last_seconds(minutes * 60)
|
||||
|
||||
async def get_calls_in_last_hours(self, hours: int) -> dict[str, int]:
|
||||
"""获取过去 N 小时内的调用次数 (总数、成功、失败)"""
|
||||
return await self.get_calls_in_last_seconds(hours * 3600)
|
||||
|
||||
async def get_calls_in_current_month(self) -> dict[str, int]:
|
||||
"""获取当前自然月内的调用次数 (总数、成功、失败)"""
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
start_of_month = now.replace(
|
||||
day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
query = select(
|
||||
func.count(RequestLog.id).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
and_(
|
||||
RequestLog.status_code >= 200,
|
||||
RequestLog.status_code < 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
).label("success"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
or_(
|
||||
RequestLog.status_code < 200,
|
||||
RequestLog.status_code >= 300,
|
||||
),
|
||||
1,
|
||||
),
|
||||
(RequestLog.status_code is None, 1), # type: ignore
|
||||
else_=0,
|
||||
)
|
||||
).label("failure"),
|
||||
).where(RequestLog.request_time >= start_of_month)
|
||||
result = await database.fetch_one(query)
|
||||
if result:
|
||||
return {
|
||||
"total": result["total"] or 0,
|
||||
"success": result["success"] or 0,
|
||||
"failure": result["failure"] or 0,
|
||||
}
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get calls in current month: {e}")
|
||||
return {"total": 0, "success": 0, "failure": 0}
|
||||
|
||||
async def get_api_usage_stats(self) -> dict:
|
||||
"""获取所有需要的 API 使用统计数据 (总数、成功、失败)"""
|
||||
try:
|
||||
stats_1m = await self.get_calls_in_last_minutes(1)
|
||||
stats_1h = await self.get_calls_in_last_hours(1)
|
||||
stats_24h = await self.get_calls_in_last_hours(24)
|
||||
stats_month = await self.get_calls_in_current_month()
|
||||
|
||||
return {
|
||||
"calls_1m": stats_1m,
|
||||
"calls_1h": stats_1h,
|
||||
"calls_24h": stats_24h,
|
||||
"calls_month": stats_month,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get API usage stats: {e}")
|
||||
default_stat = {"total": 0, "success": 0, "failure": 0}
|
||||
return {
|
||||
"calls_1m": default_stat.copy(),
|
||||
"calls_1h": default_stat.copy(),
|
||||
"calls_24h": default_stat.copy(),
|
||||
"calls_month": default_stat.copy(),
|
||||
}
|
||||
|
||||
async def get_api_call_details(self, period: str) -> list[dict]:
|
||||
"""
|
||||
获取指定时间段内的 API 调用详情
|
||||
|
||||
Args:
|
||||
period: 时间段标识 ('1m', '1h', '24h')
|
||||
|
||||
Returns:
|
||||
包含调用详情的字典列表,每个字典包含 timestamp, key, model, status
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 period 无效
|
||||
"""
|
||||
now = datetime.datetime.now()
|
||||
if period == "1m":
|
||||
start_time = now - datetime.timedelta(minutes=1)
|
||||
elif period == "1h":
|
||||
start_time = now - datetime.timedelta(hours=1)
|
||||
elif period == "24h":
|
||||
start_time = now - datetime.timedelta(hours=24)
|
||||
else:
|
||||
raise ValueError(f"无效的时间段标识: {period}")
|
||||
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
RequestLog.request_time.label("timestamp"),
|
||||
RequestLog.api_key.label("key"),
|
||||
RequestLog.model_name.label("model"),
|
||||
RequestLog.status_code, # We might need to map this to 'success'/'failure' later
|
||||
)
|
||||
.where(RequestLog.request_time >= start_time)
|
||||
.order_by(RequestLog.request_time.desc())
|
||||
) # Order by most recent first
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
# Convert results to list of dicts and map status_code
|
||||
details = []
|
||||
for row in results:
|
||||
status = "failure" # 默认状态为 failure,如果 status_code 有效且在 200-299 范围内则更新为 success
|
||||
if row["status_code"] is not None: # 检查 status_code 是否为空
|
||||
status = "success" if 200 <= row["status_code"] < 300 else "failure"
|
||||
details.append(
|
||||
{
|
||||
"timestamp": row[
|
||||
"timestamp"
|
||||
].isoformat(), # Use ISO format for JS compatibility
|
||||
"key": row["key"],
|
||||
"model": row["model"],
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Retrieved {len(details)} API call details for period '{period}'"
|
||||
)
|
||||
return details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get API call details for period '{period}': {e}")
|
||||
# Re-raise the exception to be handled by the route
|
||||
raise
|
||||
|
||||
async def get_key_usage_details_last_24h(self, key: str) -> dict | None:
|
||||
"""
|
||||
获取指定 API 密钥在过去 24 小时内按模型统计的调用次数。
|
||||
|
||||
Args:
|
||||
key: 要查询的 API 密钥。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是模型名称,值是调用次数。
|
||||
如果查询出错或没有找到记录,可能返回 None 或空字典。
|
||||
Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}
|
||||
"""
|
||||
logger.info(
|
||||
f"Fetching usage details for key ending in ...{key[-4:]} for the last 24h."
|
||||
)
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=24)
|
||||
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
RequestLog.model_name, func.count(RequestLog.id).label("call_count")
|
||||
)
|
||||
.where(
|
||||
RequestLog.api_key == key,
|
||||
RequestLog.request_time >= cutoff_time,
|
||||
RequestLog.model_name.isnot(None), # Ensure model_name is not null
|
||||
)
|
||||
.group_by(RequestLog.model_name)
|
||||
.order_by(func.count(RequestLog.id).desc()) # Order by count descending
|
||||
)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
if not results:
|
||||
logger.info(
|
||||
f"No usage details found for key ending in ...{key[-4:]} in the last 24h."
|
||||
)
|
||||
return {} # Return empty dict if no records found
|
||||
|
||||
usage_details = {row["model_name"]: row["call_count"] for row in results}
|
||||
logger.info(
|
||||
f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}"
|
||||
)
|
||||
return usage_details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get key usage details for key ending in ...{key[-4:]}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Depending on requirements, you might return None or raise the exception
|
||||
# Raising allows the route handler to return a 500 error.
|
||||
raise # Re-raise the exception
|
||||
108
app/service/update/update_service.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import httpx
|
||||
from packaging import version
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_update_logger
|
||||
|
||||
logger = get_update_logger()
|
||||
|
||||
# GitHub repository details are read from settings (defined in app/config/config.py or environment variables)
|
||||
|
||||
# GITHUB_API_URL will be constructed inside the function to ensure settings are loaded
|
||||
|
||||
VERSION_FILE_PATH = "VERSION" # Path relative to project root
|
||||
|
||||
async def check_for_updates() -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
"""
|
||||
通过比较当前版本与最新的 GitHub release 来检查应用程序更新。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], Optional[str]]: 一个元组,包含:
|
||||
- bool: 如果有可用更新则为 True,否则为 False。
|
||||
- Optional[str]: 如果有可用更新,则为最新的版本字符串,否则为 None。
|
||||
- Optional[str]: 如果检查失败,则为错误消息,否则为 None。
|
||||
"""
|
||||
try:
|
||||
# Read current version from VERSION file
|
||||
# Ensure the path is correct relative to the execution context or use absolute path if needed
|
||||
# Assuming execution from project root d:/develop/pythonProjects/gemini-balance
|
||||
with open(VERSION_FILE_PATH, 'r', encoding='utf-8') as f:
|
||||
current_v = f.read().strip()
|
||||
if not current_v:
|
||||
logger.error(f"VERSION file ('{VERSION_FILE_PATH}') is empty.")
|
||||
return False, None, f"VERSION file ('{VERSION_FILE_PATH}') is empty."
|
||||
except FileNotFoundError:
|
||||
logger.error(f"VERSION file not found at '{VERSION_FILE_PATH}'. Make sure it exists in the project root.")
|
||||
return False, None, f"VERSION file not found at '{VERSION_FILE_PATH}'."
|
||||
except IOError as e:
|
||||
logger.error(f"Error reading VERSION file ('{VERSION_FILE_PATH}'): {e}")
|
||||
return False, None, f"Error reading VERSION file ('{VERSION_FILE_PATH}')."
|
||||
|
||||
logger.info(f"当前应用程序版本 (from {VERSION_FILE_PATH}): {current_v}")
|
||||
|
||||
# Check if repository details are configured in settings
|
||||
if not settings.GITHUB_REPO_OWNER or not settings.GITHUB_REPO_NAME or \
|
||||
settings.GITHUB_REPO_OWNER == "your_owner" or settings.GITHUB_REPO_NAME == "your_repo":
|
||||
logger.warning("GitHub repository owner/name not configured in settings. Skipping update check.")
|
||||
return False, None, "Update check skipped: Repository not configured in settings."
|
||||
|
||||
# Construct the API URL inside the function to ensure settings are loaded
|
||||
github_api_url = f"https://api.github.com/repos/{settings.GITHUB_REPO_OWNER}/{settings.GITHUB_REPO_NAME}/releases/latest"
|
||||
logger.debug(f"Checking for updates at URL: {github_api_url}") # Log the URL for debugging
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# 添加 User-Agent 头,GitHub API 可能需要
|
||||
headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": f"{settings.GITHUB_REPO_NAME}-UpdateChecker/1.0" # Use repo name from settings for User-Agent
|
||||
}
|
||||
response = await client.get(github_api_url, headers=headers) # Use the locally constructed URL
|
||||
response.raise_for_status() # 对错误的 HTTP 状态码(4xx 或 5xx)抛出异常
|
||||
|
||||
latest_release = response.json()
|
||||
latest_v_str = latest_release.get("tag_name")
|
||||
|
||||
if not latest_v_str:
|
||||
logger.warning("在最新的 GitHub release 响应中找不到 'tag_name'。")
|
||||
return False, None, "无法从 GitHub 解析最新版本。"
|
||||
|
||||
# 移除 tag 名称中可能存在的 'v' 前缀
|
||||
if latest_v_str.startswith('v'):
|
||||
latest_v_str = latest_v_str[1:]
|
||||
|
||||
logger.info(f"在 GitHub 上找到的最新版本: {latest_v_str}")
|
||||
|
||||
# 比较版本
|
||||
current_version = version.parse(current_v)
|
||||
latest_version = version.parse(latest_v_str)
|
||||
|
||||
if latest_version > current_version:
|
||||
logger.info(f"有可用更新: {current_v} -> {latest_v_str}")
|
||||
return True, latest_v_str, None
|
||||
else:
|
||||
logger.info("应用程序已是最新版本。")
|
||||
return False, None, None
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"检查更新时发生 HTTP 错误: {e.response.status_code} - {e.response.text}")
|
||||
# 避免向用户显示详细的错误文本
|
||||
error_msg = f"获取更新信息失败 (HTTP {e.response.status_code})。"
|
||||
if e.response.status_code == 404:
|
||||
error_msg += " 请检查仓库名称是否正确或仓库是否有发布版本。"
|
||||
elif e.response.status_code == 403:
|
||||
error_msg += " API 速率限制或权限问题。"
|
||||
return False, None, error_msg
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"检查更新时发生网络错误: {e}")
|
||||
return False, None, "更新检查期间发生网络错误。"
|
||||
except version.InvalidVersion:
|
||||
# Note: latest_v_str might not be defined if the error occurs before fetching it.
|
||||
# Consider adding a check or default value for logging.
|
||||
latest_v_str_for_log = latest_v_str if 'latest_v_str' in locals() else 'N/A'
|
||||
logger.error(f"发现无效的版本格式。当前 (from {VERSION_FILE_PATH}): '{current_v}', 最新: '{latest_v_str_for_log}'")
|
||||
return False, None, "遇到无效的版本格式。"
|
||||
except Exception as e:
|
||||
logger.error(f"更新检查期间发生意外错误: {e}", exc_info=True)
|
||||
return False, None, "发生意外错误。"
|
||||
@@ -1,59 +0,0 @@
|
||||
# app/services/chat/api_client.py
|
||||
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
import httpx
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ApiClient(ABC):
|
||||
"""API客户端基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
pass
|
||||
|
||||
|
||||
class GeminiApiClient(ApiClient):
|
||||
"""Gemini API客户端"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: int = 300):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
def _get_real_model(self, model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
|
||||
return model
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
@@ -1,165 +0,0 @@
|
||||
# app/services/chat/message_converter.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
import requests
|
||||
import base64
|
||||
|
||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||
IMAGE_URL_PATTERN = r'\[image\]\((.*?)\)'
|
||||
|
||||
|
||||
class MessageConverter(ABC):
|
||||
"""消息转换器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
def _get_mime_type_and_data(base64_string):
|
||||
"""
|
||||
从 base64 字符串中提取 MIME 类型和数据。
|
||||
|
||||
参数:
|
||||
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
|
||||
|
||||
返回:
|
||||
tuple: (mime_type, encoded_data)
|
||||
"""
|
||||
# 检查字符串是否以 "data:" 格式开始
|
||||
if base64_string.startswith('data:'):
|
||||
# 提取 MIME 类型和数据
|
||||
pattern = r'data:([^;]+);base64,(.+)'
|
||||
match = re.match(pattern, base64_string)
|
||||
if match:
|
||||
mime_type = match.group(1)
|
||||
encoded_data = match.group(2)
|
||||
return mime_type, encoded_data
|
||||
|
||||
# 如果不是预期格式,假定它只是数据部分
|
||||
return None, base64_string
|
||||
|
||||
def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||
if image_url.startswith("data:image"):
|
||||
mime_type, encoded_data = _get_mime_type_and_data(image_url)
|
||||
return {
|
||||
"inline_data": {
|
||||
"mime_type": mime_type,
|
||||
"data": encoded_data
|
||||
}
|
||||
}
|
||||
return {
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _convert_image_to_base64(url: str) -> str:
|
||||
"""
|
||||
将图片URL转换为base64编码
|
||||
Args:
|
||||
url: 图片URL
|
||||
Returns:
|
||||
str: base64编码的图片数据
|
||||
"""
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# 将图片内容转换为base64
|
||||
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||
return img_data
|
||||
else:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
|
||||
|
||||
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理可能包含图片URL的文本,提取图片并转换为base64
|
||||
|
||||
Args:
|
||||
text: 可能包含图片URL的文本
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 包含文本和图片的部分列表
|
||||
"""
|
||||
parts = []
|
||||
img_url_match = re.search(IMAGE_URL_PATTERN, text)
|
||||
if img_url_match:
|
||||
# 提取URL
|
||||
img_url = img_url_match.group(1)
|
||||
# 将URL对应的图片转换为base64
|
||||
try:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append({
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": base64_data
|
||||
}
|
||||
})
|
||||
except Exception:
|
||||
# 如果转换失败,回退到文本模式
|
||||
parts.append({"text": text})
|
||||
else:
|
||||
# 没有图片URL,作为纯文本处理
|
||||
parts.append({"text": text})
|
||||
return parts
|
||||
|
||||
|
||||
class OpenAIMessageConverter(MessageConverter):
|
||||
"""OpenAI消息格式转换器"""
|
||||
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
system_instruction_parts = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
if role not in SUPPORTED_ROLES:
|
||||
if role == "tool":
|
||||
role = "user"
|
||||
else:
|
||||
# 如果是最后一条消息,则认为是用户消息
|
||||
if idx == len(messages) - 1:
|
||||
role = "user"
|
||||
else:
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
# 特别处理最后一个assistant的消息,按\n\n分割
|
||||
if role == "assistant" and idx == len(messages) - 2 and isinstance(msg["content"], str) and msg["content"]:
|
||||
# 按\n\n分割消息
|
||||
content_parts = msg["content"].split("\n\n")
|
||||
for part in content_parts:
|
||||
if not part.strip(): # 跳过空内容
|
||||
continue
|
||||
# 处理可能包含图片的文本
|
||||
parts.extend(_process_text_with_image(part))
|
||||
elif isinstance(msg["content"], str) and msg["content"]:
|
||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
||||
parts.extend(_process_text_with_image(msg["content"]))
|
||||
elif isinstance(msg["content"], list):
|
||||
for content in msg["content"]:
|
||||
if isinstance(content, str) and content:
|
||||
parts.append({"text": content})
|
||||
elif isinstance(content, dict):
|
||||
if content["type"] == "text" and content["text"]:
|
||||
parts.append({"text": content["text"]})
|
||||
elif content["type"] == "image_url":
|
||||
parts.append(_convert_image(content["image_url"]["url"]))
|
||||
|
||||
if parts:
|
||||
if role == "system":
|
||||
system_instruction_parts.extend(parts)
|
||||
else:
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
system_instruction = (
|
||||
None
|
||||
if not system_instruction_parts
|
||||
else {
|
||||
"role": "system",
|
||||
"parts": system_instruction_parts,
|
||||
}
|
||||
)
|
||||
return converted_messages, system_instruction
|
||||
@@ -1,41 +0,0 @@
|
||||
# app/services/chat/retry_handler.py
|
||||
|
||||
from typing import TypeVar, Callable
|
||||
from functools import wraps
|
||||
from app.core.logger import get_retry_logger
|
||||
|
||||
T = TypeVar('T')
|
||||
logger = get_retry_logger()
|
||||
|
||||
|
||||
class RetryHandler:
|
||||
"""重试处理装饰器"""
|
||||
|
||||
def __init__(self, max_retries: int = 3, key_arg: str = "api_key"):
|
||||
self.max_retries = max_retries
|
||||
self.key_arg = key_arg
|
||||
|
||||
def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> T:
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
|
||||
|
||||
# 从函数参数中获取 key_manager
|
||||
key_manager = kwargs.get('key_manager')
|
||||
if key_manager:
|
||||
old_key = kwargs.get(self.key_arg)
|
||||
new_key = await key_manager.handle_api_failure(old_key)
|
||||
kwargs[self.key_arg] = new_key
|
||||
logger.info(f"Switched to new API key: {new_key}")
|
||||
|
||||
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
|
||||
raise last_exception
|
||||
|
||||
return wrapper
|
||||
@@ -1,25 +0,0 @@
|
||||
from typing import Union, List
|
||||
|
||||
import openai
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
|
||||
from app.core.logger import get_embeddings_logger
|
||||
|
||||
logger = get_embeddings_logger()
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def create_embedding(
|
||||
self, input_text: Union[str, List[str]], model: str, api_key: str
|
||||
) -> CreateEmbeddingResponse:
|
||||
"""Create embeddings using OpenAI API"""
|
||||
try:
|
||||
client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(input=input_text, model=model)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating embedding: {str(e)}")
|
||||
raise
|
||||
@@ -1,149 +0,0 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, AsyncGenerator, List
|
||||
from app.core.logger import get_gemini_logger
|
||||
from app.services.chat.api_client import GeminiApiClient
|
||||
from app.services.chat.stream_optimizer import gemini_optimizer
|
||||
from app.schemas.gemini_models import GeminiRequest
|
||||
from app.core.config import settings
|
||||
from app.services.chat.response_handler import GeminiResponseHandler
|
||||
from app.services.key_manager import KeyManager
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片部分"""
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
if "image_url" in part or "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
tools = []
|
||||
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
|
||||
model.endswith("-search") or "-thinking" in model
|
||||
) and not _has_image_parts(payload.get("contents", [])):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
tools.append({"googleSearch": {}})
|
||||
|
||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||
items = payload.get("tools", [])
|
||||
if items and isinstance(items, list):
|
||||
tools.extend(items)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
|
||||
]
|
||||
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
payload = {
|
||||
"contents": request_dict.get("contents", []),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig", {}),
|
||||
"systemInstruction": request_dict.get("systemInstruction", "")
|
||||
}
|
||||
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
|
||||
return payload
|
||||
|
||||
|
||||
class GeminiChatService:
|
||||
"""聊天服务"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager):
|
||||
self.api_client = GeminiApiClient(base_url)
|
||||
self.key_manager = key_manager
|
||||
self.response_handler = GeminiResponseHandler()
|
||||
|
||||
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
|
||||
"""从响应中提取文本内容"""
|
||||
if not response.get("candidates"):
|
||||
return ""
|
||||
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts and "text" in parts[0]:
|
||||
return parts[0].get("text", "")
|
||||
return ""
|
||||
|
||||
def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的响应"""
|
||||
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
|
||||
if response_copy.get("candidates") and response_copy["candidates"][0].get("content", {}).get("parts"):
|
||||
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
|
||||
return response_copy
|
||||
|
||||
async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
payload = _build_payload(model, request)
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
|
||||
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
retries = 0
|
||||
max_retries = 3
|
||||
payload = _build_payload(model, request)
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(payload, model, api_key):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
line = line[6:]
|
||||
response_data = self.response_handler.handle_response(json.loads(line), model, stream=True)
|
||||
text = self._extract_text_from_response(response_data)
|
||||
|
||||
# 如果有文本内容,使用流式输出优化器处理
|
||||
if text:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for optimized_chunk in gemini_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_response(response_data, t),
|
||||
lambda c: "data: " + json.dumps(c) + "\n\n"
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
# 如果没有文本内容(如工具调用等),整块输出
|
||||
yield "data: " + json.dumps(response_data) + "\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
break
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
|
||||
api_key = await self.key_manager.handle_api_failure(api_key)
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
|
||||
break
|
||||
@@ -1,105 +0,0 @@
|
||||
import asyncio
|
||||
from itertools import cycle
|
||||
from typing import Dict
|
||||
from app.core.logger import get_key_manager_logger
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = get_key_manager_logger()
|
||||
|
||||
|
||||
class KeyManager:
|
||||
def __init__(self, api_keys: list):
|
||||
self.api_keys = api_keys
|
||||
self.key_cycle = cycle(api_keys)
|
||||
self.key_cycle_lock = asyncio.Lock()
|
||||
self.failure_count_lock = asyncio.Lock()
|
||||
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
|
||||
self.MAX_FAILURES = settings.MAX_FAILURES
|
||||
self.paid_key = settings.PAID_KEY
|
||||
|
||||
async def get_paid_key(self) -> str:
|
||||
return self.paid_key
|
||||
|
||||
async def get_next_key(self) -> str:
|
||||
"""获取下一个API key"""
|
||||
async with self.key_cycle_lock:
|
||||
return next(self.key_cycle)
|
||||
|
||||
async def is_key_valid(self, key: str) -> bool:
|
||||
"""检查key是否有效"""
|
||||
async with self.failure_count_lock:
|
||||
return self.key_failure_counts[key] < self.MAX_FAILURES
|
||||
|
||||
async def reset_failure_counts(self):
|
||||
"""重置所有key的失败计数"""
|
||||
async with self.failure_count_lock:
|
||||
for key in self.key_failure_counts:
|
||||
self.key_failure_counts[key] = 0
|
||||
|
||||
async def get_next_working_key(self) -> str:
|
||||
"""获取下一可用的API key"""
|
||||
initial_key = await self.get_next_key()
|
||||
current_key = initial_key
|
||||
|
||||
while True:
|
||||
if await self.is_key_valid(current_key):
|
||||
return current_key
|
||||
|
||||
current_key = await self.get_next_key()
|
||||
if current_key == initial_key:
|
||||
# await self.reset_failure_counts() 取消重置
|
||||
return current_key
|
||||
|
||||
async def handle_api_failure(self, api_key: str) -> str:
|
||||
"""处理API调用失败"""
|
||||
async with self.failure_count_lock:
|
||||
self.key_failure_counts[api_key] += 1
|
||||
if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
)
|
||||
|
||||
return await self.get_next_working_key()
|
||||
|
||||
def get_fail_count(self, key: str) -> int:
|
||||
"""获取指定密钥的失败次数"""
|
||||
return self.key_failure_counts.get(key, 0)
|
||||
|
||||
async def get_keys_by_status(self) -> dict:
|
||||
"""获取分类后的API key列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
invalid_keys = {}
|
||||
|
||||
async with self.failure_count_lock:
|
||||
for key in self.api_keys:
|
||||
fail_count = self.key_failure_counts[key]
|
||||
if fail_count < self.MAX_FAILURES:
|
||||
valid_keys[key] = fail_count
|
||||
else:
|
||||
invalid_keys[key] = fail_count
|
||||
|
||||
return {
|
||||
"valid_keys": valid_keys,
|
||||
"invalid_keys": invalid_keys
|
||||
}
|
||||
|
||||
|
||||
_singleton_instance = None
|
||||
_singleton_lock = asyncio.Lock()
|
||||
|
||||
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
|
||||
"""
|
||||
获取 KeyManager 单例实例。
|
||||
|
||||
如果尚未创建实例,将使用提供的 api_keys 初始化 KeyManager。
|
||||
如果已创建实例,则忽略 api_keys 参数,返回现有单例。
|
||||
"""
|
||||
global _singleton_instance
|
||||
|
||||
async with _singleton_lock:
|
||||
if _singleton_instance is None:
|
||||
if api_keys is None:
|
||||
raise ValueError("API keys are required to initialize the KeyManager")
|
||||
_singleton_instance = KeyManager(api_keys)
|
||||
return _singleton_instance
|
||||
@@ -1,84 +0,0 @@
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
from app.core.logger import get_model_logger
|
||||
from app.core.config import settings
|
||||
|
||||
logger = get_model_logger()
|
||||
|
||||
class ModelService:
|
||||
def __init__(self, model_search: list, model_image: list):
|
||||
self.model_search = model_search
|
||||
self.model_image = model_image
|
||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
url = f"{self.base_url}/models?key={api_key}"
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
gemini_models = response.json()
|
||||
return gemini_models
|
||||
else:
|
||||
logger.error(f"Error: {response.status_code}")
|
||||
logger.error(response.text)
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
return None
|
||||
|
||||
def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
gemini_models = self.get_gemini_models(api_key)
|
||||
return self.convert_to_openai_models_format(gemini_models)
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
return None
|
||||
|
||||
def convert_to_openai_models_format(
|
||||
self, gemini_models: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
openai_format = {"object": "list", "data": [], "success": True}
|
||||
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
openai_model = {
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": int(datetime.now(timezone.utc).timestamp()),
|
||||
"owned_by": "google",
|
||||
"permission": [],
|
||||
"root": model["name"],
|
||||
"parent": None,
|
||||
}
|
||||
openai_format["data"].append(openai_model)
|
||||
|
||||
if model_id in self.model_search:
|
||||
search_model = openai_model.copy()
|
||||
search_model["id"] = f"{model_id}-search"
|
||||
openai_format["data"].append(search_model)
|
||||
if model_id in self.model_image:
|
||||
image_model = openai_model.copy()
|
||||
image_model["id"] = f"{model_id}-image"
|
||||
openai_format["data"].append(image_model)
|
||||
|
||||
if settings.CREATE_IMAGE_MODEL:
|
||||
image_model = openai_model.copy()
|
||||
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
openai_format["data"].append(image_model)
|
||||
return openai_format
|
||||
|
||||
def check_model_support(self, model: str) -> bool:
|
||||
if not model or not isinstance(model, str):
|
||||
return False
|
||||
|
||||
model = model.strip()
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
return model in settings.MODEL_SEARCH
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
return model in settings.MODEL_IMAGE
|
||||
|
||||
return True
|
||||
@@ -1,275 +0,0 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
from copy import deepcopy
|
||||
import json
|
||||
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
|
||||
from app.core.logger import get_openai_logger
|
||||
from app.services.chat.message_converter import OpenAIMessageConverter
|
||||
from app.services.chat.response_handler import OpenAIResponseHandler
|
||||
from app.services.chat.api_client import GeminiApiClient
|
||||
from app.services.chat.stream_optimizer import openai_optimizer
|
||||
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.core.config import settings
|
||||
from app.services.image_create_service import ImageCreateService
|
||||
from app.services.key_manager import KeyManager
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片部分"""
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
if "image_url" in part or "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_tools(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
tools = []
|
||||
model = request.model
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image") or model.endswith("-image-generation"))
|
||||
and not _has_image_parts(messages)
|
||||
):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
tools.append({"googleSearch": {}})
|
||||
|
||||
# 将 request 中的 tools 合并到 tools 中
|
||||
if request.tools:
|
||||
function_declarations = []
|
||||
for tool in request.tools:
|
||||
if not tool or not isinstance(tool, dict):
|
||||
continue
|
||||
|
||||
if tool.get("type", "") == "function" and tool.get("function"):
|
||||
function = deepcopy(tool.get("function"))
|
||||
parameters = function.get("parameters", {})
|
||||
if parameters.get("type") == "object" and not parameters.get("properties", {}):
|
||||
function.pop("parameters", None)
|
||||
|
||||
function_declarations.append(function)
|
||||
|
||||
if function_declarations:
|
||||
# 按照 function 的 name 去重
|
||||
names, functions = set(), []
|
||||
for item in function_declarations:
|
||||
if item.get("name") not in names:
|
||||
names.add(item.get("name"))
|
||||
functions.append(item)
|
||||
|
||||
tools.append({"functionDeclarations": functions})
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
# if (
|
||||
# "2.0" in model
|
||||
# and "gemini-2.0-flash-thinking-exp" not in model
|
||||
# and "gemini-2.0-pro-exp" not in model
|
||||
# ):
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
payload = {
|
||||
"contents": messages,
|
||||
"generationConfig": {
|
||||
"temperature": request.temperature,
|
||||
"maxOutputTokens": request.max_tokens,
|
||||
"stopSequences": request.stop,
|
||||
"topP": request.top_p,
|
||||
"topK": request.top_k,
|
||||
},
|
||||
"tools": _build_tools(request, messages),
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
|
||||
|
||||
if (
|
||||
instruction
|
||||
and isinstance(instruction, dict)
|
||||
and instruction.get("role") == "system"
|
||||
and instruction.get("parts")
|
||||
and not request.model.endswith("-image")
|
||||
and not request.model.endswith("-image-generation")
|
||||
):
|
||||
payload["systemInstruction"] = instruction
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class OpenAIChatService:
|
||||
"""聊天服务"""
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
self.message_converter = OpenAIMessageConverter()
|
||||
self.response_handler = OpenAIResponseHandler(config=None)
|
||||
self.api_client = GeminiApiClient(base_url)
|
||||
self.key_manager = key_manager
|
||||
self.image_create_service = ImageCreateService()
|
||||
|
||||
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
|
||||
"""从OpenAI响应块中提取文本内容"""
|
||||
if not chunk.get("choices"):
|
||||
return ""
|
||||
|
||||
choice = chunk["choices"][0]
|
||||
if "delta" in choice and "content" in choice["delta"]:
|
||||
return choice["delta"]["content"]
|
||||
return ""
|
||||
|
||||
def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的OpenAI响应块"""
|
||||
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
|
||||
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
|
||||
chunk_copy["choices"][0]["delta"]["content"] = text
|
||||
return chunk_copy
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
# 转换消息格式
|
||||
messages, instruction = self.message_converter.convert(request.messages)
|
||||
|
||||
# 构建请求payload
|
||||
payload = _build_payload(request, messages, instruction)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, payload, api_key)
|
||||
return await self._handle_normal_completion(request.model, payload, api_key)
|
||||
|
||||
async def _handle_normal_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理普通聊天完成"""
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
return self.response_handler.handle_response(
|
||||
response, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑"""
|
||||
retries = 0
|
||||
max_retries = 3
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, api_key
|
||||
):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
chunk = json.loads(line[6:])
|
||||
openai_chunk = self.response_handler.handle_response(
|
||||
chunk, model, stream=True, finish_reason=None
|
||||
)
|
||||
if openai_chunk:
|
||||
# 提取文本内容
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for optimized_chunk in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n"
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
# 如果没有文本内容(如工具调用等),整块输出
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("Streaming completed successfully")
|
||||
break # 成功后退出循环
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
api_key = await self.key_manager.handle_api_failure(api_key)
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming. Raising error"
|
||||
)
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
|
||||
image_generate_request = ImageGenerationRequest()
|
||||
image_generate_request.prompt = request.messages[-1]["content"]
|
||||
image_res = self.image_create_service.generate_images_chat(image_generate_request)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_image_completion(request.model,image_res)
|
||||
else:
|
||||
return self._handle_normal_image_completion(request.model,image_res)
|
||||
|
||||
async def _handle_stream_image_completion(
|
||||
self, model: str, image_data: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if image_data:
|
||||
openai_chunk = self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=True, finish_reason=None
|
||||
)
|
||||
if openai_chunk:
|
||||
# 提取文本内容
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for optimized_chunk in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n"
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
# 如果没有文本内容(如图片URL等),整块输出
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("Image chat streaming completed successfully")
|
||||
|
||||
def _handle_normal_image_completion(
|
||||
self, model: str, image_data: str
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
return self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
@@ -1,249 +0,0 @@
|
||||
body {
|
||||
font-family: 'Roboto', sans-serif;
|
||||
line-height: 1.6;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 400px;
|
||||
width: 90%;
|
||||
background: rgba(255, 255, 255, 0.95);
|
||||
padding: 40px;
|
||||
border-radius: 20px;
|
||||
box-shadow: 0 15px 35px rgba(0,0,0,0.2);
|
||||
backdrop-filter: blur(10px);
|
||||
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
}
|
||||
|
||||
.container:hover {
|
||||
transform: translateY(-5px);
|
||||
box-shadow: 0 20px 40px rgba(0,0,0,0.25);
|
||||
}
|
||||
|
||||
.logo {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
animation: fadeIn 1s ease;
|
||||
}
|
||||
|
||||
.logo i {
|
||||
font-size: 48px;
|
||||
color: #764ba2;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
h2 {
|
||||
color: #2c3e50;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
font-weight: 700;
|
||||
font-size: 24px;
|
||||
animation: slideDown 0.5s ease;
|
||||
}
|
||||
|
||||
form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
position: relative;
|
||||
animation: slideUp 0.5s ease;
|
||||
}
|
||||
|
||||
.input-group i {
|
||||
position: absolute;
|
||||
left: 12px;
|
||||
top: 50%;
|
||||
transform: translateY(-50%);
|
||||
color: #764ba2;
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
input {
|
||||
width: 100%;
|
||||
padding: 12px 12px 12px 40px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 10px;
|
||||
font-size: 16px;
|
||||
box-sizing: border-box;
|
||||
transition: all 0.3s ease;
|
||||
background: rgba(255, 255, 255, 0.9);
|
||||
}
|
||||
|
||||
input:focus {
|
||||
border-color: #764ba2;
|
||||
box-shadow: 0 0 10px rgba(118, 75, 162, 0.2);
|
||||
outline: none;
|
||||
}
|
||||
|
||||
button {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 14px;
|
||||
border-radius: 10px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
font-weight: bold;
|
||||
transition: all 0.3s ease;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(118, 75, 162, 0.3);
|
||||
}
|
||||
|
||||
button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
button::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
width: 0;
|
||||
height: 0;
|
||||
background: rgba(255, 255, 255, 0.2);
|
||||
border-radius: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
transition: width 0.6s, height 0.6s;
|
||||
}
|
||||
|
||||
button:active::after {
|
||||
width: 200px;
|
||||
height: 200px;
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
.error-message {
|
||||
color: #e74c3c;
|
||||
margin-top: 15px;
|
||||
text-align: center;
|
||||
font-weight: bold;
|
||||
padding: 10px;
|
||||
border-radius: 5px;
|
||||
background: rgba(231, 76, 60, 0.1);
|
||||
animation: shake 0.5s ease;
|
||||
}
|
||||
|
||||
.copyright {
|
||||
position: fixed;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
background: rgba(255, 255, 255, 0.9);
|
||||
padding: 10px 0;
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
color: #2c3e50;
|
||||
backdrop-filter: blur(5px);
|
||||
border-top: 1px solid rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.copyright a {
|
||||
color: #764ba2;
|
||||
text-decoration: none;
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.copyright a:hover {
|
||||
color: #667eea;
|
||||
}
|
||||
|
||||
.copyright img {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border-radius: 50%;
|
||||
vertical-align: middle;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; }
|
||||
to { opacity: 1; }
|
||||
}
|
||||
|
||||
@keyframes slideDown {
|
||||
from { transform: translateY(-20px); opacity: 0; }
|
||||
to { transform: translateY(0); opacity: 1; }
|
||||
}
|
||||
|
||||
@keyframes slideUp {
|
||||
from { transform: translateY(20px); opacity: 0; }
|
||||
to { transform: translateY(0); opacity: 1; }
|
||||
}
|
||||
|
||||
@keyframes shake {
|
||||
0%, 100% { transform: translateX(0); }
|
||||
25% { transform: translateX(-5px); }
|
||||
75% { transform: translateX(5px); }
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.container {
|
||||
width: 85%;
|
||||
padding: 30px;
|
||||
}
|
||||
.logo i {
|
||||
font-size: 40px;
|
||||
}
|
||||
h2 {
|
||||
font-size: 22px;
|
||||
}
|
||||
input {
|
||||
padding: 10px 10px 10px 35px;
|
||||
font-size: 15px;
|
||||
}
|
||||
.input-group i {
|
||||
font-size: 16px;
|
||||
}
|
||||
button {
|
||||
padding: 12px;
|
||||
font-size: 15px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
.container {
|
||||
width: 90%;
|
||||
padding: 25px;
|
||||
}
|
||||
.logo i {
|
||||
font-size: 36px;
|
||||
}
|
||||
h2 {
|
||||
font-size: 20px;
|
||||
margin-bottom: 25px;
|
||||
}
|
||||
form {
|
||||
gap: 15px;
|
||||
}
|
||||
input {
|
||||
padding: 10px 10px 10px 32px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.input-group i {
|
||||
font-size: 15px;
|
||||
left: 10px;
|
||||
}
|
||||
button {
|
||||
padding: 10px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.error-message {
|
||||
font-size: 14px;
|
||||
padding: 8px;
|
||||
margin-top: 12px;
|
||||
}
|
||||
}
|
||||
@@ -1,461 +0,0 @@
|
||||
body {
|
||||
font-family: 'Roboto', sans-serif;
|
||||
line-height: 1.6;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 900px;
|
||||
width: 95%;
|
||||
background: rgba(255, 255, 255, 0.95);
|
||||
padding: 40px;
|
||||
border-radius: 20px;
|
||||
box-shadow: 0 15px 35px rgba(0,0,0,0.2);
|
||||
backdrop-filter: blur(10px);
|
||||
position: relative;
|
||||
margin: 20px auto;
|
||||
overflow-y: auto;
|
||||
max-height: calc(100vh - 40px);
|
||||
scrollbar-width: none;
|
||||
-ms-overflow-style: none;
|
||||
}
|
||||
|
||||
.container::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #2c3e50;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
font-weight: 700;
|
||||
font-size: 32px;
|
||||
position: relative;
|
||||
padding-bottom: 15px;
|
||||
}
|
||||
|
||||
h1::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
width: 100px;
|
||||
height: 4px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border-radius: 2px;
|
||||
}
|
||||
|
||||
.key-list {
|
||||
margin-bottom: 30px;
|
||||
background: rgba(248, 249, 250, 0.9);
|
||||
padding: 25px;
|
||||
border-radius: 15px;
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid rgba(0,0,0,0.1);
|
||||
animation: fadeIn 0.5s ease forwards;
|
||||
}
|
||||
|
||||
.key-list:hover {
|
||||
transform: translateY(-5px);
|
||||
box-shadow: 0 10px 20px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.key-list:nth-child(2) {
|
||||
animation-delay: 0.2s;
|
||||
}
|
||||
|
||||
.key-list h2 {
|
||||
color: #2c3e50;
|
||||
margin-bottom: 20px;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
font-size: 1.5em;
|
||||
padding-bottom: 10px;
|
||||
border-bottom: 2px solid rgba(0,0,0,0.1);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.key-list h2 .toggle-icon {
|
||||
margin-right: 10px;
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.key-list h2 .toggle-icon.collapsed {
|
||||
transform: rotate(-90deg);
|
||||
}
|
||||
|
||||
.key-list .key-content {
|
||||
transition: all 0.3s ease-out;
|
||||
overflow: hidden;
|
||||
height: auto;
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.key-list .key-content.collapsed {
|
||||
height: 0;
|
||||
opacity: 0;
|
||||
padding-top: 0;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
|
||||
ul {
|
||||
list-style-type: none;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
li {
|
||||
background: white;
|
||||
border: 1px solid rgba(0,0,0,0.1);
|
||||
margin-bottom: 12px;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
transition: all 0.3s ease;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
box-shadow: 0 2px 5px rgba(0,0,0,0.05);
|
||||
}
|
||||
|
||||
li:hover {
|
||||
transform: translateX(5px);
|
||||
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.key-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.key-text {
|
||||
font-family: 'Roboto Mono', monospace;
|
||||
color: #2c3e50;
|
||||
}
|
||||
|
||||
.fail-count {
|
||||
background: rgba(231, 76, 60, 0.1);
|
||||
color: #e74c3c;
|
||||
padding: 4px 10px;
|
||||
border-radius: 15px;
|
||||
font-size: 0.85em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 5px;
|
||||
}
|
||||
|
||||
.fail-count i {
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.key-actions {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.verify-btn, .copy-btn {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 8px 15px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: bold;
|
||||
transition: all 0.3s ease;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 5px;
|
||||
}
|
||||
|
||||
.verify-btn {
|
||||
background: linear-gradient(135deg, #2ecc71, #27ae60);
|
||||
}
|
||||
|
||||
.verify-btn:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(46, 204, 113, 0.3);
|
||||
}
|
||||
|
||||
.verify-btn:disabled {
|
||||
opacity: 0.7;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
.verify-btn i {
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.copy-btn:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(118, 75, 162, 0.3);
|
||||
}
|
||||
|
||||
.copy-btn:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.copy-btn i {
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.total {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
padding: 15px 25px;
|
||||
border-radius: 10px;
|
||||
font-weight: bold;
|
||||
text-align: center;
|
||||
font-size: 1.2em;
|
||||
margin-top: 30px;
|
||||
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
#copyStatus {
|
||||
position: fixed;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
padding: 15px 30px;
|
||||
border-radius: 25px;
|
||||
font-weight: bold;
|
||||
opacity: 0;
|
||||
transition: all 0.3s ease;
|
||||
backdrop-filter: blur(5px);
|
||||
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
|
||||
z-index: 1000;
|
||||
text-align: center;
|
||||
min-width: 200px;
|
||||
color: white;
|
||||
}
|
||||
|
||||
#copyStatus.success {
|
||||
background: rgba(39, 174, 96, 0.95);
|
||||
}
|
||||
|
||||
#copyStatus.error {
|
||||
background: rgba(231, 76, 60, 0.95);
|
||||
}
|
||||
|
||||
.status-badge {
|
||||
padding: 4px 12px;
|
||||
border-radius: 15px;
|
||||
font-size: 0.9em;
|
||||
font-weight: bold;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.status-valid {
|
||||
background: rgba(39, 174, 96, 0.1);
|
||||
color: #27ae60;
|
||||
}
|
||||
|
||||
.status-invalid {
|
||||
background: rgba(231, 76, 60, 0.1);
|
||||
color: #e74c3c;
|
||||
}
|
||||
|
||||
.scroll-buttons {
|
||||
position: fixed;
|
||||
right: 20px;
|
||||
bottom: 20px;
|
||||
display: none;
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
.scroll-btn {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 20px;
|
||||
transition: all 0.3s ease;
|
||||
backdrop-filter: blur(5px);
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.scroll-btn:hover {
|
||||
background: linear-gradient(135deg, #764ba2 0%, #667eea 100%);
|
||||
transform: scale(1.1);
|
||||
}
|
||||
|
||||
.scroll-btn:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.refresh-btn {
|
||||
position: fixed;
|
||||
top: 20px;
|
||||
right: 20px;
|
||||
z-index: 1000;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 10px 20px;
|
||||
border-radius: 25px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: bold;
|
||||
transition: all 0.3s ease;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.refresh-btn:hover {
|
||||
transform: scale(1.05);
|
||||
box-shadow: 0 8px 20px rgba(118, 75, 162, 0.3);
|
||||
background: linear-gradient(135deg, #764ba2 0%, #667eea 100%);
|
||||
}
|
||||
|
||||
.refresh-btn:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.refresh-btn i {
|
||||
transition: transform 0.5s ease;
|
||||
}
|
||||
|
||||
.refresh-btn.loading i {
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
.copyright {
|
||||
position: fixed;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
background: rgba(255, 255, 255, 0.9);
|
||||
padding: 10px 0;
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
color: #2c3e50;
|
||||
backdrop-filter: blur(5px);
|
||||
border-top: 1px solid rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.copyright a {
|
||||
color: #764ba2;
|
||||
text-decoration: none;
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.copyright a:hover {
|
||||
color: #667eea;
|
||||
}
|
||||
|
||||
.copyright img {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border-radius: 50%;
|
||||
vertical-align: middle;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
from { transform: rotate(0deg); }
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.container {
|
||||
width: 100%;
|
||||
padding: 20px;
|
||||
margin: 10px auto;
|
||||
}
|
||||
body {
|
||||
padding: 10px;
|
||||
}
|
||||
h1 {
|
||||
font-size: 24px;
|
||||
}
|
||||
.key-list h2 {
|
||||
font-size: 1.2em;
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
align-items: flex-start;
|
||||
}
|
||||
.key-info {
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 8px;
|
||||
}
|
||||
li {
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
}
|
||||
.key-actions {
|
||||
width: 100%;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.verify-btn, .copy-btn {
|
||||
width: 100%;
|
||||
justify-content: center;
|
||||
}
|
||||
.key-text {
|
||||
word-break: break-all;
|
||||
}
|
||||
.scroll-buttons {
|
||||
right: 10px;
|
||||
bottom: 10px;
|
||||
}
|
||||
.scroll-btn {
|
||||
width: 35px;
|
||||
height: 35px;
|
||||
font-size: 16px;
|
||||
}
|
||||
.refresh-btn {
|
||||
top: 10px;
|
||||
right: 10px;
|
||||
padding: 8px 16px;
|
||||
font-size: 12px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
.container {
|
||||
padding: 15px;
|
||||
}
|
||||
h1 {
|
||||
font-size: 20px;
|
||||
}
|
||||
.key-list {
|
||||
padding: 15px;
|
||||
}
|
||||
.status-badge {
|
||||
padding: 3px 8px;
|
||||
font-size: 0.8em;
|
||||
}
|
||||
.fail-count {
|
||||
font-size: 0.8em;
|
||||
}
|
||||
.total {
|
||||
font-size: 1em;
|
||||
padding: 12px 20px;
|
||||
}
|
||||
}
|
||||
BIN
app/static/icons/logo.png
Normal file
|
After Width: | Height: | Size: 39 KiB |
BIN
app/static/icons/logo1.png
Normal file
|
After Width: | Height: | Size: 18 KiB |
@@ -1,18 +0,0 @@
|
||||
if ('serviceWorker' in navigator) {
|
||||
window.addEventListener('load', () => {
|
||||
navigator.serviceWorker.register('/static/service-worker.js')
|
||||
.then(registration => {
|
||||
console.log('ServiceWorker注册成功:', registration.scope);
|
||||
})
|
||||
.catch(error => {
|
||||
console.log('ServiceWorker注册失败:', error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const copyrightYear = document.querySelector('.copyright script');
|
||||
if (copyrightYear) {
|
||||
copyrightYear.textContent = new Date().getFullYear();
|
||||
}
|
||||
});
|
||||
1965
app/static/js/config_editor.js
Normal file
981
app/static/js/error_logs.js
Normal file
@@ -0,0 +1,981 @@
|
||||
// 错误日志页面JavaScript (Updated for new structure, no Bootstrap)
|
||||
|
||||
// 页面滚动功能
|
||||
function scrollToTop() {
|
||||
window.scrollTo({ top: 0, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function scrollToBottom() {
|
||||
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
// API 调用辅助函数
|
||||
async function fetchAPI(url, options = {}) {
|
||||
try {
|
||||
const response = await fetch(url, options);
|
||||
|
||||
// Handle cases where response might be empty but still ok (e.g., 204 No Content for DELETE)
|
||||
if (response.status === 204) {
|
||||
return null; // Indicate success with no content
|
||||
}
|
||||
|
||||
let responseData;
|
||||
try {
|
||||
responseData = await response.json();
|
||||
} catch (e) {
|
||||
// Handle non-JSON responses if necessary, or assume error if JSON expected
|
||||
if (!response.ok) {
|
||||
// If response is not ok and not JSON, use statusText
|
||||
throw new Error(`HTTP error! status: ${response.status} - ${response.statusText}`);
|
||||
}
|
||||
// If response is ok but not JSON, maybe return raw text or handle differently
|
||||
// For now, let's assume successful non-JSON is not expected or handled later
|
||||
console.warn("Response was not JSON for URL:", url);
|
||||
return await response.text(); // Or handle as needed
|
||||
}
|
||||
|
||||
|
||||
if (!response.ok) {
|
||||
// Prefer error message from API response body if available
|
||||
const message = responseData?.detail || `HTTP error! status: ${response.status} - ${response.statusText}`;
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
return responseData; // Return parsed JSON data for successful responses
|
||||
|
||||
} catch (error) {
|
||||
// Catch network errors or errors thrown from above
|
||||
console.error('API Call Failed:', error.message, 'URL:', url, 'Options:', options);
|
||||
// Re-throw the error so the calling function knows the operation failed
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh function removed as the buttons are gone.
|
||||
// If refresh functionality is needed elsewhere, it can be triggered directly by calling loadErrorLogs().
|
||||
|
||||
// 全局状态管理
|
||||
let errorLogState = {
|
||||
currentPage: 1,
|
||||
pageSize: 10,
|
||||
logs: [], // 存储获取的日志
|
||||
sort: {
|
||||
field: 'id', // 默认按 ID 排序
|
||||
order: 'desc' // 默认降序
|
||||
},
|
||||
search: {
|
||||
key: '',
|
||||
error: '',
|
||||
errorCode: '',
|
||||
startDate: '',
|
||||
endDate: ''
|
||||
}
|
||||
};
|
||||
|
||||
// DOM Elements Cache
|
||||
let pageSizeSelector;
|
||||
// let refreshBtn; // Removed, as the button is deleted
|
||||
let tableBody;
|
||||
let paginationElement;
|
||||
let loadingIndicator;
|
||||
let noDataMessage;
|
||||
let errorMessage;
|
||||
let logDetailModal;
|
||||
let modalCloseBtns; // Collection of close buttons for the modal
|
||||
let keySearchInput;
|
||||
let errorSearchInput;
|
||||
let errorCodeSearchInput; // Added error code input
|
||||
let startDateInput;
|
||||
let endDateInput;
|
||||
let searchBtn;
|
||||
let pageInput;
|
||||
let goToPageBtn;
|
||||
let selectAllCheckbox; // 新增:全选复选框
|
||||
let copySelectedKeysBtn; // 新增:复制选中按钮
|
||||
let deleteSelectedBtn; // 新增:批量删除按钮
|
||||
let sortByIdHeader; // 新增:ID 排序表头
|
||||
let sortIcon; // 新增:排序图标
|
||||
let selectedCountSpan; // 新增:选中计数显示
|
||||
let deleteConfirmModal; // 新增:删除确认模态框
|
||||
let closeDeleteConfirmModalBtn; // 新增:关闭删除模态框按钮
|
||||
let cancelDeleteBtn; // 新增:取消删除按钮
|
||||
let confirmDeleteBtn; // 新增:确认删除按钮
|
||||
let deleteConfirmMessage; // 新增:删除确认消息元素
|
||||
let idsToDeleteGlobally = []; // 新增:存储待删除的ID
|
||||
|
||||
// Helper functions for initialization
|
||||
function cacheDOMElements() {
|
||||
pageSizeSelector = document.getElementById('pageSize');
|
||||
tableBody = document.getElementById('errorLogsTable');
|
||||
paginationElement = document.getElementById('pagination');
|
||||
loadingIndicator = document.getElementById('loadingIndicator');
|
||||
noDataMessage = document.getElementById('noDataMessage');
|
||||
errorMessage = document.getElementById('errorMessage');
|
||||
logDetailModal = document.getElementById('logDetailModal');
|
||||
modalCloseBtns = document.querySelectorAll('#closeLogDetailModalBtn, #closeModalFooterBtn');
|
||||
keySearchInput = document.getElementById('keySearch');
|
||||
errorSearchInput = document.getElementById('errorSearch');
|
||||
errorCodeSearchInput = document.getElementById('errorCodeSearch');
|
||||
startDateInput = document.getElementById('startDate');
|
||||
endDateInput = document.getElementById('endDate');
|
||||
searchBtn = document.getElementById('searchBtn');
|
||||
pageInput = document.getElementById('pageInput');
|
||||
goToPageBtn = document.getElementById('goToPageBtn');
|
||||
selectAllCheckbox = document.getElementById('selectAllCheckbox');
|
||||
copySelectedKeysBtn = document.getElementById('copySelectedKeysBtn');
|
||||
deleteSelectedBtn = document.getElementById('deleteSelectedBtn');
|
||||
sortByIdHeader = document.getElementById('sortById');
|
||||
if (sortByIdHeader) {
|
||||
sortIcon = sortByIdHeader.querySelector('i');
|
||||
}
|
||||
selectedCountSpan = document.getElementById('selectedCount');
|
||||
deleteConfirmModal = document.getElementById('deleteConfirmModal');
|
||||
closeDeleteConfirmModalBtn = document.getElementById('closeDeleteConfirmModalBtn');
|
||||
cancelDeleteBtn = document.getElementById('cancelDeleteBtn');
|
||||
confirmDeleteBtn = document.getElementById('confirmDeleteBtn');
|
||||
deleteConfirmMessage = document.getElementById('deleteConfirmMessage');
|
||||
}
|
||||
|
||||
function initializePageSizeControls() {
|
||||
if (pageSizeSelector) {
|
||||
pageSizeSelector.value = errorLogState.pageSize;
|
||||
pageSizeSelector.addEventListener('change', function() {
|
||||
errorLogState.pageSize = parseInt(this.value);
|
||||
errorLogState.currentPage = 1; // Reset to first page
|
||||
loadErrorLogs();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function initializeSearchControls() {
|
||||
if (searchBtn) {
|
||||
searchBtn.addEventListener('click', function() {
|
||||
errorLogState.search.key = keySearchInput ? keySearchInput.value.trim() : '';
|
||||
errorLogState.search.error = errorSearchInput ? errorSearchInput.value.trim() : '';
|
||||
errorLogState.search.errorCode = errorCodeSearchInput ? errorCodeSearchInput.value.trim() : '';
|
||||
errorLogState.search.startDate = startDateInput ? startDateInput.value : '';
|
||||
errorLogState.search.endDate = endDateInput ? endDateInput.value : '';
|
||||
errorLogState.currentPage = 1; // Reset to first page on new search
|
||||
loadErrorLogs();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function initializeModalControls() {
|
||||
// Log Detail Modal
|
||||
if (logDetailModal && modalCloseBtns) {
|
||||
modalCloseBtns.forEach(btn => {
|
||||
btn.addEventListener('click', closeLogDetailModal);
|
||||
});
|
||||
logDetailModal.addEventListener('click', function(event) {
|
||||
if (event.target === logDetailModal) {
|
||||
closeLogDetailModal();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Delete Confirm Modal
|
||||
if (closeDeleteConfirmModalBtn) {
|
||||
closeDeleteConfirmModalBtn.addEventListener('click', hideDeleteConfirmModal);
|
||||
}
|
||||
if (cancelDeleteBtn) {
|
||||
cancelDeleteBtn.addEventListener('click', hideDeleteConfirmModal);
|
||||
}
|
||||
if (confirmDeleteBtn) {
|
||||
confirmDeleteBtn.addEventListener('click', handleConfirmDelete);
|
||||
}
|
||||
if (deleteConfirmModal) {
|
||||
deleteConfirmModal.addEventListener('click', function(event) {
|
||||
if (event.target === deleteConfirmModal) {
|
||||
hideDeleteConfirmModal();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function initializePaginationJumpControls() {
|
||||
if (goToPageBtn && pageInput) {
|
||||
goToPageBtn.addEventListener('click', function() {
|
||||
const targetPage = parseInt(pageInput.value);
|
||||
if (!isNaN(targetPage) && targetPage >= 1) {
|
||||
errorLogState.currentPage = targetPage;
|
||||
loadErrorLogs();
|
||||
pageInput.value = '';
|
||||
} else {
|
||||
showNotification('请输入有效的页码', 'error', 2000);
|
||||
pageInput.value = '';
|
||||
}
|
||||
});
|
||||
pageInput.addEventListener('keypress', function(event) {
|
||||
if (event.key === 'Enter') {
|
||||
goToPageBtn.click();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function initializeActionControls() {
|
||||
if (deleteSelectedBtn) {
|
||||
deleteSelectedBtn.addEventListener('click', handleDeleteSelected);
|
||||
}
|
||||
if (sortByIdHeader) {
|
||||
sortByIdHeader.addEventListener('click', handleSortById);
|
||||
}
|
||||
// Bulk selection listeners are closely related to actions
|
||||
setupBulkSelectionListeners();
|
||||
}
|
||||
|
||||
// 页面加载完成后执行
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
cacheDOMElements();
|
||||
initializePageSizeControls();
|
||||
initializeSearchControls();
|
||||
initializeModalControls();
|
||||
initializePaginationJumpControls();
|
||||
initializeActionControls();
|
||||
|
||||
// Initial load of error logs
|
||||
loadErrorLogs();
|
||||
|
||||
// Add event listeners for copy buttons inside the modal and table
|
||||
// This needs to be called after initial render and potentially after each render if content is dynamic
|
||||
setupCopyButtons();
|
||||
});
|
||||
|
||||
// 新增:显示删除确认模态框
|
||||
function showDeleteConfirmModal(message) {
|
||||
if (deleteConfirmModal && deleteConfirmMessage) {
|
||||
deleteConfirmMessage.textContent = message;
|
||||
deleteConfirmModal.classList.add('show');
|
||||
document.body.style.overflow = 'hidden'; // Prevent body scrolling
|
||||
}
|
||||
}
|
||||
|
||||
// 新增:隐藏删除确认模态框
|
||||
function hideDeleteConfirmModal() {
|
||||
if (deleteConfirmModal) {
|
||||
deleteConfirmModal.classList.remove('show');
|
||||
document.body.style.overflow = ''; // Restore body scrolling
|
||||
idsToDeleteGlobally = []; // 清空待删除ID
|
||||
}
|
||||
}
|
||||
|
||||
// 新增:处理确认删除按钮点击
|
||||
function handleConfirmDelete() {
|
||||
if (idsToDeleteGlobally.length > 0) {
|
||||
performActualDelete(idsToDeleteGlobally);
|
||||
}
|
||||
hideDeleteConfirmModal(); // 关闭模态框
|
||||
}
|
||||
|
||||
// Fallback copy function using document.execCommand
|
||||
function fallbackCopyTextToClipboard(text) {
|
||||
const textArea = document.createElement("textarea");
|
||||
textArea.value = text;
|
||||
|
||||
// Avoid scrolling to bottom
|
||||
textArea.style.top = "0";
|
||||
textArea.style.left = "0";
|
||||
textArea.style.position = "fixed";
|
||||
|
||||
document.body.appendChild(textArea);
|
||||
textArea.focus();
|
||||
textArea.select();
|
||||
|
||||
let successful = false;
|
||||
try {
|
||||
successful = document.execCommand('copy');
|
||||
} catch (err) {
|
||||
console.error('Fallback copy failed:', err);
|
||||
successful = false;
|
||||
}
|
||||
|
||||
document.body.removeChild(textArea);
|
||||
return successful;
|
||||
}
|
||||
|
||||
// Helper function to handle feedback after copy attempt (both modern and fallback)
|
||||
function handleCopyResult(buttonElement, success) {
|
||||
const originalIcon = buttonElement.querySelector('i').className; // Store original icon class
|
||||
const iconElement = buttonElement.querySelector('i');
|
||||
if (success) {
|
||||
iconElement.className = 'fas fa-check text-success-500'; // Use checkmark icon class
|
||||
showNotification('已复制到剪贴板', 'success', 2000);
|
||||
} else {
|
||||
iconElement.className = 'fas fa-times text-danger-500'; // Use error icon class
|
||||
showNotification('复制失败', 'error', 3000);
|
||||
}
|
||||
setTimeout(() => { iconElement.className = originalIcon; }, success ? 2000 : 3000); // Restore original icon class
|
||||
}
|
||||
|
||||
// 新的内部辅助函数,封装实际的复制操作和反馈
|
||||
function _performCopy(text, buttonElement) {
|
||||
let copySuccess = false;
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
if (buttonElement) {
|
||||
handleCopyResult(buttonElement, true);
|
||||
} else {
|
||||
showNotification('已复制到剪贴板', 'success');
|
||||
}
|
||||
}).catch(err => {
|
||||
console.error('Clipboard API failed, attempting fallback:', err);
|
||||
copySuccess = fallbackCopyTextToClipboard(text);
|
||||
if (buttonElement) {
|
||||
handleCopyResult(buttonElement, copySuccess);
|
||||
} else {
|
||||
showNotification(copySuccess ? '已复制到剪贴板' : '复制失败', copySuccess ? 'success' : 'error');
|
||||
}
|
||||
});
|
||||
} else {
|
||||
console.warn("Clipboard API not available or context insecure. Using fallback copy method.");
|
||||
copySuccess = fallbackCopyTextToClipboard(text);
|
||||
if (buttonElement) {
|
||||
handleCopyResult(buttonElement, copySuccess);
|
||||
} else {
|
||||
showNotification(copySuccess ? '已复制到剪贴板' : '复制失败', copySuccess ? 'success' : 'error');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function to set up copy button listeners (using modern API with fallback) - Updated to handle table copy buttons
|
||||
function setupCopyButtons(containerSelector = 'body') {
|
||||
// Find buttons within the specified container (defaults to body)
|
||||
const container = document.querySelector(containerSelector);
|
||||
if (!container) return;
|
||||
|
||||
const copyButtons = container.querySelectorAll('.copy-btn');
|
||||
copyButtons.forEach(button => {
|
||||
// Remove existing listener to prevent duplicates if called multiple times
|
||||
button.removeEventListener('click', handleCopyButtonClick);
|
||||
// Add the listener
|
||||
button.addEventListener('click', handleCopyButtonClick);
|
||||
});
|
||||
}
|
||||
|
||||
// Extracted click handler logic for reusability and removing listeners
|
||||
function handleCopyButtonClick() {
|
||||
const button = this; // 'this' refers to the button clicked
|
||||
const targetId = button.getAttribute('data-target');
|
||||
const textToCopyDirect = button.getAttribute('data-copy-text'); // For direct text copy (e.g., table key)
|
||||
let textToCopy = '';
|
||||
|
||||
if (textToCopyDirect) {
|
||||
textToCopy = textToCopyDirect;
|
||||
} else if (targetId) {
|
||||
const targetElement = document.getElementById(targetId);
|
||||
if (targetElement) {
|
||||
textToCopy = targetElement.textContent;
|
||||
} else {
|
||||
console.error('Target element not found:', targetId);
|
||||
showNotification('复制出错:找不到目标元素', 'error');
|
||||
return; // Exit if target element not found
|
||||
}
|
||||
} else {
|
||||
console.error('No data-target or data-copy-text attribute found on button:', button);
|
||||
showNotification('复制出错:未指定复制内容', 'error');
|
||||
return; // Exit if no source specified
|
||||
}
|
||||
|
||||
|
||||
if (textToCopy) {
|
||||
_performCopy(textToCopy, button); // 使用新的辅助函数
|
||||
} else {
|
||||
console.warn('No text found to copy for target:', targetId || 'direct text');
|
||||
showNotification('没有内容可复制', 'warning');
|
||||
}
|
||||
} // End of handleCopyButtonClick function
|
||||
|
||||
// 新增:设置批量选择相关的事件监听器
|
||||
function setupBulkSelectionListeners() {
|
||||
if (selectAllCheckbox) {
|
||||
selectAllCheckbox.addEventListener('change', handleSelectAllChange);
|
||||
}
|
||||
|
||||
if (tableBody) {
|
||||
// 使用事件委托处理行复选框的点击
|
||||
tableBody.addEventListener('change', handleRowCheckboxChange);
|
||||
}
|
||||
|
||||
if (copySelectedKeysBtn) {
|
||||
copySelectedKeysBtn.addEventListener('click', handleCopySelectedKeys);
|
||||
}
|
||||
|
||||
// 新增:为批量删除按钮添加事件监听器 (如果尚未添加)
|
||||
// 通常在 DOMContentLoaded 中添加一次即可
|
||||
// if (deleteSelectedBtn && !deleteSelectedBtn.hasListener) {
|
||||
// deleteSelectedBtn.addEventListener('click', handleDeleteSelected);
|
||||
// deleteSelectedBtn.hasListener = true; // 标记已添加
|
||||
// }
|
||||
}
|
||||
|
||||
// 新增:处理“全选”复选框变化的函数
|
||||
function handleSelectAllChange() {
|
||||
const isChecked = selectAllCheckbox.checked;
|
||||
const rowCheckboxes = tableBody.querySelectorAll('.row-checkbox');
|
||||
rowCheckboxes.forEach(checkbox => {
|
||||
checkbox.checked = isChecked;
|
||||
});
|
||||
updateSelectedState();
|
||||
}
|
||||
|
||||
// 新增:处理行复选框变化的函数 (事件委托)
|
||||
function handleRowCheckboxChange(event) {
|
||||
if (event.target.classList.contains('row-checkbox')) {
|
||||
updateSelectedState();
|
||||
}
|
||||
}
|
||||
|
||||
// 新增:更新选中状态(计数、按钮状态、全选框状态)
|
||||
function updateSelectedState() {
|
||||
const rowCheckboxes = tableBody.querySelectorAll('.row-checkbox');
|
||||
const selectedCheckboxes = tableBody.querySelectorAll('.row-checkbox:checked');
|
||||
const selectedCount = selectedCheckboxes.length;
|
||||
|
||||
// 移除了数字显示,不再更新selectedCountSpan
|
||||
// 仍然更新复制按钮的禁用状态
|
||||
if (copySelectedKeysBtn) {
|
||||
copySelectedKeysBtn.disabled = selectedCount === 0;
|
||||
|
||||
// 可选:根据选中项数量更新按钮标题属性
|
||||
copySelectedKeysBtn.setAttribute('title', `复制${selectedCount}项选中密钥`);
|
||||
}
|
||||
// 新增:更新批量删除按钮的禁用状态
|
||||
if (deleteSelectedBtn) {
|
||||
deleteSelectedBtn.disabled = selectedCount === 0;
|
||||
deleteSelectedBtn.setAttribute('title', `删除${selectedCount}项选中日志`);
|
||||
}
|
||||
|
||||
// 更新“全选”复选框的状态
|
||||
if (selectAllCheckbox) {
|
||||
if (rowCheckboxes.length > 0 && selectedCount === rowCheckboxes.length) {
|
||||
selectAllCheckbox.checked = true;
|
||||
selectAllCheckbox.indeterminate = false;
|
||||
} else if (selectedCount > 0) {
|
||||
selectAllCheckbox.checked = false;
|
||||
selectAllCheckbox.indeterminate = true; // 部分选中状态
|
||||
} else {
|
||||
selectAllCheckbox.checked = false;
|
||||
selectAllCheckbox.indeterminate = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 新增:处理“复制选中密钥”按钮点击的函数
|
||||
function handleCopySelectedKeys() {
|
||||
const selectedCheckboxes = tableBody.querySelectorAll('.row-checkbox:checked');
|
||||
const keysToCopy = [];
|
||||
selectedCheckboxes.forEach(checkbox => {
|
||||
const key = checkbox.getAttribute('data-key');
|
||||
if (key) {
|
||||
keysToCopy.push(key);
|
||||
}
|
||||
});
|
||||
|
||||
if (keysToCopy.length > 0) {
|
||||
const textToCopy = keysToCopy.join('\n'); // 每行一个密钥
|
||||
_performCopy(textToCopy, copySelectedKeysBtn); // 使用新的辅助函数
|
||||
} else {
|
||||
showNotification('没有选中的密钥可复制', 'warning');
|
||||
}
|
||||
}
|
||||
|
||||
// 修改:处理批量删除按钮点击的函数 - 改为显示模态框
|
||||
function handleDeleteSelected() {
|
||||
const selectedCheckboxes = tableBody.querySelectorAll('.row-checkbox:checked');
|
||||
const logIdsToDelete = [];
|
||||
selectedCheckboxes.forEach(checkbox => {
|
||||
const logId = checkbox.getAttribute('data-log-id'); // 需要在渲染时添加 data-log-id
|
||||
if (logId) {
|
||||
logIdsToDelete.push(parseInt(logId));
|
||||
}
|
||||
});
|
||||
|
||||
if (logIdsToDelete.length === 0) {
|
||||
showNotification('没有选中的日志可删除', 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
if (logIdsToDelete.length === 0) {
|
||||
showNotification('没有选中的日志可删除', 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
// 存储待删除ID并显示模态框
|
||||
idsToDeleteGlobally = logIdsToDelete;
|
||||
const message = `确定要删除选中的 ${logIdsToDelete.length} 条日志吗?此操作不可恢复!`;
|
||||
showDeleteConfirmModal(message);
|
||||
}
|
||||
|
||||
// 新增:执行实际的删除操作(提取自原 handleDeleteSelected 和 handleDeleteLogRow)
|
||||
async function performActualDelete(logIds) {
|
||||
if (!logIds || logIds.length === 0) return;
|
||||
|
||||
const isSingleDelete = logIds.length === 1;
|
||||
const url = isSingleDelete ? `/api/logs/errors/${logIds[0]}` : '/api/logs/errors';
|
||||
const method = 'DELETE';
|
||||
const body = isSingleDelete ? null : JSON.stringify({ ids: logIds });
|
||||
const headers = isSingleDelete ? {} : { 'Content-Type': 'application/json' };
|
||||
const options = {
|
||||
method: method,
|
||||
headers: headers,
|
||||
body: body, // fetchAPI handles null body correctly
|
||||
};
|
||||
|
||||
try {
|
||||
// Use fetchAPI for the delete request
|
||||
await fetchAPI(url, options); // fetchAPI returns null for 204 No Content
|
||||
|
||||
// If fetchAPI doesn't throw, the request was successful
|
||||
const successMessage = isSingleDelete ? `成功删除该日志` : `成功删除 ${logIds.length} 条日志`;
|
||||
showNotification(successMessage, 'success');
|
||||
// 取消全选
|
||||
if (selectAllCheckbox) selectAllCheckbox.checked = false;
|
||||
// 重新加载当前页数据
|
||||
loadErrorLogs();
|
||||
} catch (error) {
|
||||
console.error('批量删除错误日志失败:', error);
|
||||
showNotification(`批量删除失败: ${error.message}`, 'error', 5000);
|
||||
}
|
||||
}
|
||||
|
||||
// 修改:处理单行删除按钮点击的函数 - 改为显示模态框
|
||||
function handleDeleteLogRow(logId) {
|
||||
if (!logId) return;
|
||||
|
||||
// 存储待删除ID并显示模态框
|
||||
idsToDeleteGlobally = [parseInt(logId)]; // 存储为数组
|
||||
// 使用通用确认消息,不显示具体ID
|
||||
const message = `确定要删除这条日志吗?此操作不可恢复!`;
|
||||
showDeleteConfirmModal(message);
|
||||
}
|
||||
|
||||
// 新增:处理 ID 排序点击的函数
|
||||
function handleSortById() {
|
||||
if (errorLogState.sort.field === 'id') {
|
||||
// 如果当前是按 ID 排序,切换顺序
|
||||
errorLogState.sort.order = errorLogState.sort.order === 'asc' ? 'desc' : 'asc';
|
||||
} else {
|
||||
// 如果当前不是按 ID 排序,切换到按 ID 排序,默认为降序
|
||||
errorLogState.sort.field = 'id';
|
||||
errorLogState.sort.order = 'desc';
|
||||
}
|
||||
// 更新图标
|
||||
updateSortIcon();
|
||||
// 重新加载第一页数据
|
||||
errorLogState.currentPage = 1;
|
||||
loadErrorLogs();
|
||||
}
|
||||
|
||||
// 新增:更新排序图标的函数
|
||||
function updateSortIcon() {
|
||||
if (!sortIcon) return;
|
||||
// 移除所有可能的排序类
|
||||
sortIcon.classList.remove('fa-sort', 'fa-sort-up', 'fa-sort-down', 'text-gray-400', 'text-primary-600');
|
||||
|
||||
if (errorLogState.sort.field === 'id') {
|
||||
sortIcon.classList.add(errorLogState.sort.order === 'asc' ? 'fa-sort-up' : 'fa-sort-down');
|
||||
sortIcon.classList.add('text-primary-600'); // 高亮显示
|
||||
} else {
|
||||
// 如果不是按 ID 排序,显示默认图标
|
||||
sortIcon.classList.add('fa-sort', 'text-gray-400');
|
||||
}
|
||||
}
|
||||
|
||||
// 加载错误日志数据
|
||||
async function loadErrorLogs() {
|
||||
// 重置选择状态
|
||||
if (selectAllCheckbox) selectAllCheckbox.checked = false;
|
||||
if (selectAllCheckbox) selectAllCheckbox.indeterminate = false;
|
||||
updateSelectedState(); // 更新按钮状态和计数
|
||||
|
||||
showLoading(true);
|
||||
showError(false);
|
||||
showNoData(false);
|
||||
|
||||
const offset = (errorLogState.currentPage - 1) * errorLogState.pageSize;
|
||||
|
||||
try {
|
||||
// Construct the API URL with search and sort parameters
|
||||
let apiUrl = `/api/logs/errors?limit=${errorLogState.pageSize}&offset=${offset}`;
|
||||
// 添加排序参数
|
||||
apiUrl += `&sort_by=${errorLogState.sort.field}&sort_order=${errorLogState.sort.order}`;
|
||||
|
||||
// 添加搜索参数
|
||||
if (errorLogState.search.key) {
|
||||
apiUrl += `&key_search=${encodeURIComponent(errorLogState.search.key)}`;
|
||||
}
|
||||
if (errorLogState.search.error) {
|
||||
apiUrl += `&error_search=${encodeURIComponent(errorLogState.search.error)}`;
|
||||
}
|
||||
if (errorLogState.search.errorCode) { // Add error code to API request
|
||||
apiUrl += `&error_code_search=${encodeURIComponent(errorLogState.search.errorCode)}`;
|
||||
}
|
||||
if (errorLogState.search.startDate) {
|
||||
apiUrl += `&start_date=${encodeURIComponent(errorLogState.search.startDate)}`;
|
||||
}
|
||||
if (errorLogState.search.endDate) {
|
||||
apiUrl += `&end_date=${encodeURIComponent(errorLogState.search.endDate)}`;
|
||||
}
|
||||
|
||||
// Use fetchAPI to get logs
|
||||
const data = await fetchAPI(apiUrl);
|
||||
|
||||
// API 现在返回 { logs: [], total: count }
|
||||
// fetchAPI already parsed JSON
|
||||
if (data && Array.isArray(data.logs)) {
|
||||
errorLogState.logs = data.logs; // Store the list data (contains error_code)
|
||||
renderErrorLogs(errorLogState.logs);
|
||||
updatePagination(errorLogState.logs.length, data.total || -1); // Use total from response
|
||||
} else {
|
||||
// Handle unexpected data format even after successful fetch
|
||||
console.error('Unexpected API response format:', data);
|
||||
throw new Error('无法识别的API响应格式');
|
||||
}
|
||||
|
||||
showLoading(false);
|
||||
|
||||
if (errorLogState.logs.length === 0) {
|
||||
showNoData(true);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取错误日志失败:', error);
|
||||
showLoading(false);
|
||||
showError(true, error.message); // Show specific error message
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create HTML for a single log row
|
||||
function _createLogRowHtml(log, sequentialId) {
|
||||
// Format date
|
||||
let formattedTime = 'N/A';
|
||||
try {
|
||||
const requestTime = new Date(log.request_time);
|
||||
if (!isNaN(requestTime)) {
|
||||
formattedTime = requestTime.toLocaleString('zh-CN', {
|
||||
year: 'numeric', month: '2-digit', day: '2-digit',
|
||||
hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false
|
||||
});
|
||||
}
|
||||
} catch (e) { console.error("Error formatting date:", e); }
|
||||
|
||||
const errorCodeContent = log.error_code || '无';
|
||||
|
||||
const maskKey = (key) => {
|
||||
if (!key || key.length < 8) return key || '无';
|
||||
return `${key.substring(0, 4)}...${key.substring(key.length - 4)}`;
|
||||
};
|
||||
const maskedKey = maskKey(log.gemini_key);
|
||||
const fullKey = log.gemini_key || '';
|
||||
|
||||
return `
|
||||
<td class="text-center px-3 py-3">
|
||||
<input type="checkbox" class="row-checkbox form-checkbox h-4 w-4 text-primary-600 border-gray-300 rounded focus:ring-primary-500" data-key="${fullKey}" data-log-id="${log.id}">
|
||||
</td>
|
||||
<td>${sequentialId}</td>
|
||||
<td class="relative group" title="${fullKey}">
|
||||
${maskedKey}
|
||||
<button class="copy-btn absolute top-1/2 right-2 transform -translate-y-1/2 bg-gray-200 hover:bg-gray-300 text-gray-600 p-1 rounded opacity-0 group-hover:opacity-100 transition-opacity text-xs" data-copy-text="${fullKey}" title="复制完整密钥">
|
||||
<i class="far fa-copy"></i>
|
||||
</button>
|
||||
</td>
|
||||
<td>${log.error_type || '未知'}</td>
|
||||
<td class="error-code-content" title="${log.error_code || ''}">${errorCodeContent}</td>
|
||||
<td>${log.model_name || '未知'}</td>
|
||||
<td>${formattedTime}</td>
|
||||
<td>
|
||||
<button class="btn-view-details mr-2" data-log-id="${log.id}">
|
||||
<i class="fas fa-eye mr-1"></i>详情
|
||||
</button>
|
||||
<button class="btn-delete-row text-danger-600 hover:text-danger-800" data-log-id="${log.id}" title="删除此日志">
|
||||
<i class="fas fa-trash-alt"></i>
|
||||
</button>
|
||||
</td>
|
||||
`;
|
||||
}
|
||||
|
||||
// 渲染错误日志表格
|
||||
function renderErrorLogs(logs) {
|
||||
if (!tableBody) return;
|
||||
tableBody.innerHTML = ''; // Clear previous entries
|
||||
|
||||
// 重置全选复选框状态(在清空表格后)
|
||||
if (selectAllCheckbox) {
|
||||
selectAllCheckbox.checked = false;
|
||||
selectAllCheckbox.indeterminate = false;
|
||||
}
|
||||
|
||||
if (!logs || logs.length === 0) {
|
||||
// Handled by showNoData
|
||||
return;
|
||||
}
|
||||
|
||||
const startIndex = (errorLogState.currentPage - 1) * errorLogState.pageSize;
|
||||
|
||||
logs.forEach((log, index) => {
|
||||
const sequentialId = startIndex + index + 1;
|
||||
const row = document.createElement('tr');
|
||||
row.innerHTML = _createLogRowHtml(log, sequentialId);
|
||||
tableBody.appendChild(row);
|
||||
});
|
||||
|
||||
// Add event listeners to new 'View Details' buttons
|
||||
document.querySelectorAll('.btn-view-details').forEach(button => {
|
||||
button.addEventListener('click', function() {
|
||||
const logId = parseInt(this.getAttribute('data-log-id'));
|
||||
showLogDetails(logId);
|
||||
});
|
||||
});
|
||||
|
||||
// 新增:为新渲染的删除按钮添加事件监听器
|
||||
document.querySelectorAll('.btn-delete-row').forEach(button => {
|
||||
button.addEventListener('click', function() {
|
||||
const logId = this.getAttribute('data-log-id');
|
||||
handleDeleteLogRow(logId);
|
||||
});
|
||||
});
|
||||
|
||||
// Re-initialize copy buttons specifically for the newly rendered table rows
|
||||
setupCopyButtons('#errorLogsTable');
|
||||
// Update selected state after rendering
|
||||
updateSelectedState();
|
||||
}
|
||||
|
||||
// 显示错误日志详情 (从 API 获取)
|
||||
async function showLogDetails(logId) {
|
||||
if (!logDetailModal) return;
|
||||
|
||||
// Show loading state in modal (optional)
|
||||
// Clear previous content and show a spinner or message
|
||||
document.getElementById('modalGeminiKey').textContent = '加载中...';
|
||||
document.getElementById('modalErrorType').textContent = '加载中...';
|
||||
document.getElementById('modalErrorLog').textContent = '加载中...';
|
||||
document.getElementById('modalRequestMsg').textContent = '加载中...';
|
||||
document.getElementById('modalModelName').textContent = '加载中...';
|
||||
document.getElementById('modalRequestTime').textContent = '加载中...';
|
||||
|
||||
logDetailModal.classList.add('show');
|
||||
document.body.style.overflow = 'hidden'; // Prevent body scrolling
|
||||
|
||||
try {
|
||||
// Use fetchAPI to get log details
|
||||
const logDetails = await fetchAPI(`/api/logs/errors/${logId}/details`);
|
||||
|
||||
// fetchAPI handles response.ok check and JSON parsing
|
||||
if (!logDetails) {
|
||||
// Handle case where API returns success but no data (if possible)
|
||||
throw new Error('未找到日志详情');
|
||||
}
|
||||
|
||||
// Format date
|
||||
let formattedTime = 'N/A';
|
||||
try {
|
||||
const requestTime = new Date(logDetails.request_time);
|
||||
if (!isNaN(requestTime)) {
|
||||
formattedTime = requestTime.toLocaleString('zh-CN', {
|
||||
year: 'numeric', month: '2-digit', day: '2-digit',
|
||||
hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false
|
||||
});
|
||||
}
|
||||
} catch (e) { console.error("Error formatting date:", e); }
|
||||
|
||||
// Format request message (handle potential JSON)
|
||||
let formattedRequestMsg = '无';
|
||||
if (logDetails.request_msg) {
|
||||
try {
|
||||
if (typeof logDetails.request_msg === 'object' && logDetails.request_msg !== null) {
|
||||
formattedRequestMsg = JSON.stringify(logDetails.request_msg, null, 2);
|
||||
} else if (typeof logDetails.request_msg === 'string') {
|
||||
// Try parsing if it looks like JSON, otherwise display as string
|
||||
const trimmedMsg = logDetails.request_msg.trim();
|
||||
if (trimmedMsg.startsWith('{') || trimmedMsg.startsWith('[')) {
|
||||
formattedRequestMsg = JSON.stringify(JSON.parse(logDetails.request_msg), null, 2);
|
||||
} else {
|
||||
formattedRequestMsg = logDetails.request_msg;
|
||||
}
|
||||
} else {
|
||||
formattedRequestMsg = String(logDetails.request_msg);
|
||||
}
|
||||
} catch (e) {
|
||||
formattedRequestMsg = String(logDetails.request_msg); // Fallback
|
||||
console.warn("Could not parse request_msg as JSON:", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Populate modal content with fetched details
|
||||
document.getElementById('modalGeminiKey').textContent = logDetails.gemini_key || '无';
|
||||
document.getElementById('modalErrorType').textContent = logDetails.error_type || '未知';
|
||||
document.getElementById('modalErrorLog').textContent = logDetails.error_log || '无'; // Full error log
|
||||
document.getElementById('modalRequestMsg').textContent = formattedRequestMsg; // Full request message
|
||||
document.getElementById('modalModelName').textContent = logDetails.model_name || '未知';
|
||||
document.getElementById('modalRequestTime').textContent = formattedTime;
|
||||
|
||||
// Re-initialize copy buttons specifically for the modal after content is loaded
|
||||
setupCopyButtons('#logDetailModal');
|
||||
|
||||
} catch (error) {
|
||||
console.error('获取日志详情失败:', error);
|
||||
// Show error in modal
|
||||
document.getElementById('modalGeminiKey').textContent = '错误';
|
||||
document.getElementById('modalErrorType').textContent = '错误';
|
||||
document.getElementById('modalErrorLog').textContent = `加载失败: ${error.message}`;
|
||||
document.getElementById('modalRequestMsg').textContent = '错误';
|
||||
document.getElementById('modalModelName').textContent = '错误';
|
||||
document.getElementById('modalRequestTime').textContent = '错误';
|
||||
// Optionally show a notification
|
||||
showNotification(`加载日志详情失败: ${error.message}`, 'error', 5000);
|
||||
}
|
||||
}
|
||||
|
||||
// Close Log Detail Modal
|
||||
function closeLogDetailModal() {
|
||||
if (logDetailModal) {
|
||||
logDetailModal.classList.remove('show');
|
||||
// Optional: Restore body scrolling
|
||||
document.body.style.overflow = '';
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 更新分页控件
|
||||
function updatePagination(currentItemCount, totalItems) {
|
||||
if (!paginationElement) return;
|
||||
paginationElement.innerHTML = ''; // Clear existing pagination
|
||||
|
||||
// Calculate total pages only if totalItems is known and valid
|
||||
let totalPages = 1;
|
||||
if (totalItems >= 0) {
|
||||
totalPages = Math.max(1, Math.ceil(totalItems / errorLogState.pageSize));
|
||||
} else if (currentItemCount < errorLogState.pageSize && errorLogState.currentPage === 1) {
|
||||
// If less items than page size fetched on page 1, assume it's the only page
|
||||
totalPages = 1;
|
||||
} else {
|
||||
// If total is unknown and more items might exist, we can't build full pagination
|
||||
// We can show Prev/Next based on current page and if items were returned
|
||||
console.warn("Total item count unknown, pagination will be limited.");
|
||||
// Basic Prev/Next for unknown total
|
||||
addPaginationLink(paginationElement, '«', errorLogState.currentPage > 1, () => { errorLogState.currentPage--; loadErrorLogs(); });
|
||||
addPaginationLink(paginationElement, errorLogState.currentPage.toString(), true, null, true); // Current page number (non-clickable)
|
||||
addPaginationLink(paginationElement, '»', currentItemCount === errorLogState.pageSize, () => { errorLogState.currentPage++; loadErrorLogs(); }); // Next enabled if full page was returned
|
||||
return; // Exit here for limited pagination
|
||||
}
|
||||
|
||||
|
||||
const maxPagesToShow = 5; // Max number of page links to show
|
||||
let startPage = Math.max(1, errorLogState.currentPage - Math.floor(maxPagesToShow / 2));
|
||||
let endPage = Math.min(totalPages, startPage + maxPagesToShow - 1);
|
||||
|
||||
// Adjust startPage if endPage reaches the limit first
|
||||
if (endPage === totalPages) {
|
||||
startPage = Math.max(1, endPage - maxPagesToShow + 1);
|
||||
}
|
||||
|
||||
|
||||
// Previous Button
|
||||
addPaginationLink(paginationElement, '«', errorLogState.currentPage > 1, () => { errorLogState.currentPage--; loadErrorLogs(); });
|
||||
|
||||
// First Page Button
|
||||
if (startPage > 1) {
|
||||
addPaginationLink(paginationElement, '1', true, () => { errorLogState.currentPage = 1; loadErrorLogs(); });
|
||||
if (startPage > 2) {
|
||||
addPaginationLink(paginationElement, '...', false); // Ellipsis
|
||||
}
|
||||
}
|
||||
|
||||
// Page Number Buttons
|
||||
for (let i = startPage; i <= endPage; i++) {
|
||||
addPaginationLink(paginationElement, i.toString(), true, () => { errorLogState.currentPage = i; loadErrorLogs(); }, i === errorLogState.currentPage);
|
||||
}
|
||||
|
||||
// Last Page Button
|
||||
if (endPage < totalPages) {
|
||||
if (endPage < totalPages - 1) {
|
||||
addPaginationLink(paginationElement, '...', false); // Ellipsis
|
||||
}
|
||||
addPaginationLink(paginationElement, totalPages.toString(), true, () => { errorLogState.currentPage = totalPages; loadErrorLogs(); });
|
||||
}
|
||||
|
||||
|
||||
// Next Button
|
||||
addPaginationLink(paginationElement, '»', errorLogState.currentPage < totalPages, () => { errorLogState.currentPage++; loadErrorLogs(); });
|
||||
}
|
||||
|
||||
// Helper function to add pagination links
|
||||
function addPaginationLink(parentElement, text, enabled, clickHandler, isActive = false) {
|
||||
const pageItem = document.createElement('li');
|
||||
// 移除 'page-item' 和 'active' 类,使用 Tailwind 类进行样式化
|
||||
// pageItem.className = `page-item ${!enabled ? 'disabled' : ''} ${isActive ? 'active' : ''}`;
|
||||
|
||||
const pageLink = document.createElement('a');
|
||||
// 使用 Tailwind 类进行样式化
|
||||
pageLink.className = `px-3 py-1 rounded-md text-sm transition duration-150 ease-in-out ${
|
||||
isActive
|
||||
? 'bg-primary-600 text-white font-semibold shadow-md cursor-default' // 突出当前页样式
|
||||
: enabled
|
||||
? 'bg-white text-gray-700 hover:bg-primary-50 hover:text-primary-600 border border-gray-300' // 可点击页码样式
|
||||
: 'bg-gray-100 text-gray-400 cursor-not-allowed border border-gray-200' // 禁用状态样式 (如 '...')
|
||||
}`;
|
||||
pageLink.href = '#'; // Prevent page jump
|
||||
pageLink.innerHTML = text;
|
||||
|
||||
if (enabled && clickHandler) {
|
||||
pageLink.addEventListener('click', function(e) {
|
||||
e.preventDefault();
|
||||
clickHandler();
|
||||
});
|
||||
} else if (!enabled) {
|
||||
pageLink.addEventListener('click', e => e.preventDefault()); // Prevent click on disabled or active
|
||||
} else if (isActive) {
|
||||
pageLink.addEventListener('click', e => e.preventDefault()); // Prevent click on active page
|
||||
}
|
||||
|
||||
// 不再需要 li 元素,直接将 a 元素添加到父元素
|
||||
// pageItem.appendChild(pageLink);
|
||||
parentElement.appendChild(pageLink);
|
||||
}
|
||||
|
||||
|
||||
// 显示/隐藏状态指示器 (using 'active' class)
|
||||
function showLoading(show) {
|
||||
if (loadingIndicator) loadingIndicator.style.display = show ? 'block' : 'none';
|
||||
}
|
||||
|
||||
function showNoData(show) {
|
||||
if (noDataMessage) noDataMessage.style.display = show ? 'block' : 'none';
|
||||
}
|
||||
|
||||
function showError(show, message = '加载错误日志失败,请稍后重试。') {
|
||||
if (errorMessage) {
|
||||
errorMessage.style.display = show ? 'block' : 'none';
|
||||
if (show) {
|
||||
// Update the error message content
|
||||
const p = errorMessage.querySelector('p');
|
||||
if (p) p.textContent = message;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function to show temporary status notifications (like copy success)
|
||||
function showNotification(message, type = 'success', duration = 3000) {
|
||||
const notificationElement = document.getElementById('notification'); // Use the correct ID from base.html
|
||||
if (!notificationElement) {
|
||||
console.error("Notification element with ID 'notification' not found.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Set message and type class
|
||||
notificationElement.textContent = message;
|
||||
// Remove previous type classes before adding the new one
|
||||
notificationElement.classList.remove('success', 'error', 'warning', 'info');
|
||||
notificationElement.classList.add(type); // Add the type class for styling
|
||||
notificationElement.className = `notification ${type} show`; // Add 'show' class
|
||||
|
||||
// Hide after duration
|
||||
setTimeout(() => {
|
||||
notificationElement.classList.remove('show');
|
||||
}, duration);
|
||||
}
|
||||
|
||||
// Example Usage (if copy functionality is added later):
|
||||
// showNotification('密钥已复制!', 'success');
|
||||
// showNotification('复制失败!', 'error');
|
||||
@@ -17,13 +17,27 @@ self.addEventListener('install', event => {
|
||||
|
||||
self.addEventListener('fetch', event => {
|
||||
event.respondWith(
|
||||
caches.match(event.request)
|
||||
.then(response => {
|
||||
if (response) {
|
||||
return response;
|
||||
}
|
||||
return fetch(event.request);
|
||||
})
|
||||
caches.open(CACHE_NAME).then(cache => {
|
||||
// 1. 尝试从缓存获取
|
||||
return cache.match(event.request).then(responseFromCache => {
|
||||
// 2. 同时从网络获取 (后台进行)
|
||||
const fetchPromise = fetch(event.request).then(responseFromNetwork => {
|
||||
// 3. 网络请求成功,更新缓存
|
||||
cache.put(event.request, responseFromNetwork.clone());
|
||||
return responseFromNetwork;
|
||||
}).catch(err => {
|
||||
// 网络请求失败时,可以选择记录错误或不执行任何操作
|
||||
console.error('Network fetch failed:', err);
|
||||
// 确保即使网络失败,如果缓存存在,我们仍然返回缓存
|
||||
// 如果缓存也不存在,则此 Promise 会 reject
|
||||
throw err;
|
||||
});
|
||||
|
||||
// 4. 如果缓存存在,立即返回缓存;否则等待网络响应
|
||||
// 后台的网络请求仍在进行,用于更新缓存
|
||||
return responseFromCache || fetchPromise;
|
||||
});
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,42 +1,124 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>验证页面</title>
|
||||
<link rel="manifest" href="/static/manifest.json">
|
||||
<meta name="theme-color" content="#764ba2">
|
||||
<meta name="apple-mobile-web-app-capable" content="yes">
|
||||
<meta name="apple-mobile-web-app-status-bar-style" content="black">
|
||||
<meta name="apple-mobile-web-app-title" content="GBalance">
|
||||
<link rel="icon" href="/static/icons/icon-192x192.png">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap" rel="stylesheet">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
|
||||
<link rel="stylesheet" href="/static/css/auth.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="logo">
|
||||
<i class="fas fa-shield-alt"></i>
|
||||
</div>
|
||||
<h2>安全验证</h2>
|
||||
<form id="auth-form" action="/auth" method="post">
|
||||
<div class="input-group">
|
||||
<i class="fas fa-key"></i>
|
||||
<input type="password" id="auth-token" name="auth_token" required placeholder="请输入验证令牌">
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}验证页面 - Gemini Balance{% endblock %}
|
||||
|
||||
{% block head_extra_styles %}
|
||||
<style>
|
||||
/* auth.html specific styles */
|
||||
.auth-glass-card { /* Renamed to avoid conflict if base.html has .glass-card */
|
||||
background: rgba(255, 255, 255, 0.85); /* Increased opacity */
|
||||
backdrop-filter: blur(20px);
|
||||
-webkit-backdrop-filter: blur(20px);
|
||||
border: 1px solid rgba(255, 255, 255, 0.2);
|
||||
}
|
||||
.auth-bg-gradient { /* Renamed to avoid conflict if base.html has .bg-gradient */
|
||||
background: linear-gradient(135deg, #4F46E5 0%, #7C3AED 50%, #EC4899 100%);
|
||||
}
|
||||
/* .input-icon class removed, using direct Tailwind classes now */
|
||||
/* Keep button ripple effect if needed, or remove if base provides similar */
|
||||
.auth-button { /* Renamed to avoid conflict */
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
.auth-button:after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
width: 0;
|
||||
height: 0;
|
||||
background: rgba(255, 255, 255, 0.2);
|
||||
border-radius: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
transition: width 0.6s, height 0.6s;
|
||||
}
|
||||
.auth-button:active:after {
|
||||
width: 300px;
|
||||
height: 300px;
|
||||
opacity: 0;
|
||||
}
|
||||
</style>
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="auth-bg-gradient min-h-screen flex flex-col justify-center items-center p-4">
|
||||
<div class="glass-card rounded-2xl shadow-2xl p-10 max-w-md w-full mx-auto transform transition duration-500 hover:-translate-y-1 hover:shadow-3xl animate-fade-in">
|
||||
<div class="flex justify-center mb-8 animate-slide-down">
|
||||
<div class="rounded-full bg-primary-100 p-4 text-primary-600">
|
||||
<i class="fas fa-shield-alt text-4xl"></i>
|
||||
</div>
|
||||
<button type="submit">
|
||||
验证访问
|
||||
</div>
|
||||
|
||||
<h2 class="text-3xl font-extrabold text-center text-transparent bg-clip-text bg-gradient-to-r from-primary-600 to-primary-700 mb-8 animate-slide-down">
|
||||
<img src="/static/icons/logo.png" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
|
||||
Gemini Balance
|
||||
</h2>
|
||||
|
||||
<form id="auth-form" action="/auth" method="post" class="space-y-6 animate-slide-up">
|
||||
<div class="relative">
|
||||
<i class="fas fa-key absolute left-3 top-1/2 transform -translate-y-1/2 text-gray-500"></i>
|
||||
<input
|
||||
type="password"
|
||||
id="auth-token"
|
||||
name="auth_token"
|
||||
required
|
||||
placeholder="请输入验证令牌"
|
||||
class="w-full pl-10 pr-4 py-4 rounded-xl border border-gray-300 focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 transition duration-300 bg-white bg-opacity-90 text-gray-700"
|
||||
>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
class="w-full py-4 rounded-xl bg-gradient-to-r from-primary-600 to-primary-700 text-white font-semibold transition duration-300 transform hover:-translate-y-1 hover:shadow-lg"
|
||||
>
|
||||
登录
|
||||
</button>
|
||||
</form>
|
||||
|
||||
{% if error %}
|
||||
<p class="error-message">{{ error }}</p>
|
||||
<p class="mt-4 text-red-500 text-center font-medium p-3 bg-red-50 rounded-lg border border-red-200 animate-shake">
|
||||
{{ error }}
|
||||
</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
<div class="copyright">
|
||||
© <script>document.write(new Date().getFullYear())</script> by <a href="https://linux.do/u/snaily" target="_blank"><img src="https://linux.do/user_avatar/linux.do/snaily/288/306510_2.gif" alt="snaily">snaily</a> |
|
||||
<a href="https://github.com/snailyp/gemini-balance" target="_blank"><i class="fab fa-github"></i> GitHub</a>
|
||||
</div>
|
||||
<script src="/static/js/auth.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
</div> <!-- Close auth-bg-gradient div -->
|
||||
<!-- Notification placeholder for base.html's showNotification -->
|
||||
<div id="notification" class="notification"></div>
|
||||
|
||||
{% endblock %}
|
||||
|
||||
{% block body_scripts %}
|
||||
<script>
|
||||
// auth.html specific JavaScript
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const form = document.getElementById('auth-form');
|
||||
if (form) {
|
||||
form.addEventListener('submit', function(e) {
|
||||
const token = document.getElementById('auth-token').value.trim();
|
||||
if (!token) {
|
||||
e.preventDefault();
|
||||
// Use the base notification system
|
||||
showNotification('请输入验证令牌', 'error');
|
||||
}
|
||||
});
|
||||
}
|
||||
// Apply renamed classes
|
||||
document.querySelectorAll('button[type="submit"]').forEach(button => {
|
||||
button.classList.add('auth-button');
|
||||
});
|
||||
const card = document.querySelector('.auth-glass-card'); // Find the renamed card
|
||||
if (card) {
|
||||
// If the base template also defines .glass-card, remove it first
|
||||
// card.classList.remove('glass-card');
|
||||
} else {
|
||||
// If the card wasn't found by the new name, try the old name and rename
|
||||
const oldCard = document.querySelector('.glass-card');
|
||||
if (oldCard) {
|
||||
oldCard.classList.remove('glass-card');
|
||||
oldCard.classList.add('auth-glass-card');
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
||||
377
app/templates/base.html
Normal file
@@ -0,0 +1,377 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>{% block title %}Gemini Balance{% endblock %}</title>
|
||||
<link rel="manifest" href="/static/manifest.json" />
|
||||
<meta name="theme-color" content="#4F46E5" />
|
||||
<meta name="apple-mobile-web-app-capable" content="yes" />
|
||||
<meta name="apple-mobile-web-app-status-bar-style" content="black" />
|
||||
<meta name="apple-mobile-web-app-title" content="GBalance" />
|
||||
<link rel="icon" href="/static/icons/icon-192x192.png" />
|
||||
<link
|
||||
href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css"
|
||||
/>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script>
|
||||
tailwind.config = {
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
primary: {
|
||||
50: "#eef2ff",
|
||||
100: "#e0e7ff",
|
||||
200: "#c7d2fe",
|
||||
300: "#a5b4fc",
|
||||
400: "#818cf8",
|
||||
500: "#6366f1",
|
||||
600: "#4f46e5",
|
||||
700: "#4338ca",
|
||||
800: "#3730a3",
|
||||
900: "#312e81",
|
||||
},
|
||||
success: {
|
||||
50: "#ecfdf5",
|
||||
500: "#10b981",
|
||||
600: "#059669",
|
||||
},
|
||||
danger: {
|
||||
50: "#fef2f2",
|
||||
500: "#ef4444",
|
||||
600: "#dc2626",
|
||||
},
|
||||
},
|
||||
fontFamily: {
|
||||
sans: ["Inter", "sans-serif"],
|
||||
mono: [
|
||||
"JetBrains Mono",
|
||||
"SFMono-Regular",
|
||||
"Menlo",
|
||||
"Monaco",
|
||||
"Consolas",
|
||||
"monospace",
|
||||
],
|
||||
},
|
||||
animation: {
|
||||
"fade-in": "fadeIn 0.5s ease-out",
|
||||
"slide-up": "slideUp 0.5s ease-out",
|
||||
"slide-down": "slideDown 0.5s ease-out",
|
||||
shake: "shake 0.5s ease-in-out",
|
||||
spin: "spin 1s linear infinite",
|
||||
},
|
||||
keyframes: {
|
||||
fadeIn: {
|
||||
"0%": { opacity: "0" },
|
||||
"100%": { opacity: "1" },
|
||||
},
|
||||
slideUp: {
|
||||
"0%": { transform: "translateY(20px)", opacity: "0" },
|
||||
"100%": { transform: "translateY(0)", opacity: "1" },
|
||||
},
|
||||
slideDown: {
|
||||
"0%": { transform: "translateY(-20px)", opacity: "0" },
|
||||
"100%": { transform: "translateY(0)", opacity: "1" },
|
||||
},
|
||||
shake: {
|
||||
"0%, 100%": { transform: "translateX(0)" },
|
||||
"25%": { transform: "translateX(-5px)" },
|
||||
"75%": { transform: "translateX(5px)" },
|
||||
},
|
||||
spin: {
|
||||
"0%": { transform: "rotate(0deg)" },
|
||||
"100%": { transform: "rotate(360deg)" },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
</script>
|
||||
<style>
|
||||
.glass-card {
|
||||
background: rgba(255, 255, 255, 0.85); /* Slightly increased opacity for better readability */
|
||||
backdrop-filter: blur(16px);
|
||||
-webkit-backdrop-filter: blur(16px);
|
||||
border: 1px solid rgba(255, 255, 255, 0.18); /* Subtle border */
|
||||
}
|
||||
.bg-gradient {
|
||||
background: linear-gradient(135deg, #4F46E5 0%, #7C3AED 50%, #EC4899 100%);
|
||||
}
|
||||
/* Scrollbar styling */
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
}
|
||||
::-webkit-scrollbar-track {
|
||||
background: rgba(243, 244, 246, 0.8); /* bg-gray-100 with opacity */
|
||||
border-radius: 10px;
|
||||
}
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: rgba(79, 70, 229, 0.4); /* primary-600 with opacity */
|
||||
border-radius: 10px;
|
||||
}
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(79, 70, 229, 0.6); /* primary-600 with more opacity */
|
||||
}
|
||||
/* Basic modal styles */
|
||||
.modal {
|
||||
display: none;
|
||||
position: fixed;
|
||||
z-index: 50;
|
||||
left: 0;
|
||||
top: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(0,0,0,0.5);
|
||||
backdrop-filter: blur(4px);
|
||||
}
|
||||
.modal.show {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
/* Loading spinner */
|
||||
.loading-spin {
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
@keyframes spin {
|
||||
from { transform: rotate(0deg); }
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
/* Notification */
|
||||
.notification {
|
||||
position: fixed;
|
||||
bottom: 5rem; /* Adjusted from bottom-20 */
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
padding: 0.75rem 1.25rem; /* px-5 py-3 */
|
||||
border-radius: 0.5rem; /* rounded-lg */
|
||||
background-color: rgba(0, 0, 0, 0.8);
|
||||
color: white;
|
||||
font-weight: 500; /* font-medium */
|
||||
z-index: 1000; /* Increased z-index */
|
||||
opacity: 0;
|
||||
transition: opacity 0.3s ease-in-out, transform 0.3s ease-in-out;
|
||||
}
|
||||
.notification.show {
|
||||
opacity: 1;
|
||||
transform: translate(-50%, 0);
|
||||
}
|
||||
.notification.error {
|
||||
background-color: rgba(220, 38, 38, 0.8); /* danger-600 with opacity */
|
||||
}
|
||||
/* Scroll buttons */
|
||||
.scroll-buttons {
|
||||
position: fixed;
|
||||
right: 1.25rem; /* right-5 */
|
||||
bottom: 5rem; /* bottom-20 */
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem; /* gap-2 */
|
||||
z-index: 10;
|
||||
}
|
||||
.scroll-button {
|
||||
width: 2.5rem; /* w-10 */
|
||||
height: 2.5rem; /* h-10 */
|
||||
background-color: #4f46e5; /* bg-primary-600 */
|
||||
color: white;
|
||||
border-radius: 9999px; /* rounded-full */
|
||||
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); /* shadow-md */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.3s ease-in-out;
|
||||
}
|
||||
.scroll-button:hover {
|
||||
background-color: #4338ca; /* hover:bg-primary-700 */
|
||||
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); /* hover:shadow-lg */
|
||||
}
|
||||
{% block head_extra_styles %}
|
||||
{% endblock %}
|
||||
</style>
|
||||
{% block head_extra_scripts %}{% endblock %}
|
||||
</head>
|
||||
<body class="bg-gradient min-h-screen text-gray-800 pt-6 pb-16">
|
||||
{% block content %}{% endblock %}
|
||||
|
||||
<!-- 底部版权 -->
|
||||
<div
|
||||
class="fixed bottom-0 left-0 w-full py-3 bg-white bg-opacity-80 backdrop-blur-md text-sm text-gray-800 border-t border-gray-200 flex flex-col items-center space-y-1"
|
||||
>
|
||||
<!-- 第一行 -->
|
||||
<div class="flex items-center justify-center space-x-2">
|
||||
<span>© <span id="copyright-year"></span> by</span>
|
||||
<a
|
||||
href="https://linux.do/u/snaily"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<img
|
||||
src="https://linux.do/user_avatar/linux.do/snaily/288/306510_2.gif"
|
||||
alt="snaily"
|
||||
class="inline-block w-5 h-5 rounded-full align-middle mr-1"
|
||||
/>snaily
|
||||
</a>
|
||||
<span class="text-gray-400">|</span>
|
||||
<a
|
||||
href="https://github.com/snailyp/gemini-balance"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fab fa-github mr-1"></i> GitHub
|
||||
</a>
|
||||
</div>
|
||||
<!-- 第二行 -->
|
||||
<div class="flex items-center justify-center space-x-2 text-xs">
|
||||
<a
|
||||
href="https://gb-docs.snaily.top/guide/supportme.html"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fas fa-drumstick-bite text-yellow-600 mr-1"></i> 给作者加鸡腿
|
||||
</a>
|
||||
<span class="text-gray-400">|</span>
|
||||
<a
|
||||
href="https://gb-docs.snaily.top"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fas fa-book mr-1"></i> 在线文档
|
||||
</a>
|
||||
<span class="text-gray-400">|</span>
|
||||
<a
|
||||
href="https://t.me/+soaHax5lyI0wZDVl"
|
||||
target="_blank"
|
||||
class="text-primary-600 hover:text-primary-800 transition duration-300 flex items-center"
|
||||
>
|
||||
<i class="fab fa-telegram-plane mr-1"></i> 交流群
|
||||
</a>
|
||||
<span class="text-gray-400">|</span>
|
||||
<span class="text-yellow-600 font-semibold flex items-center">
|
||||
<i class="fas fa-exclamation-triangle mr-1"></i>免费项目,谨防诈骗
|
||||
</span>
|
||||
<span id="version-info-container" class="inline-flex items-center">
|
||||
<!-- Version info will be loaded here by JavaScript -->
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 通用JS -->
|
||||
<script>
|
||||
// 设置版权年份
|
||||
document.getElementById("copyright-year").textContent =
|
||||
new Date().getFullYear();
|
||||
|
||||
// 滚动到顶部/底部函数 (如果页面需要)
|
||||
function scrollToTop() {
|
||||
window.scrollTo({ top: 0, behavior: "smooth" });
|
||||
}
|
||||
function scrollToBottom() {
|
||||
window.scrollTo({
|
||||
top: document.body.scrollHeight,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}
|
||||
|
||||
// 显示通知
|
||||
function showNotification(message, type = "success", duration = 3000) {
|
||||
const notification =
|
||||
document.getElementById("notification") ||
|
||||
createNotificationElement();
|
||||
if (!notification) return;
|
||||
|
||||
notification.textContent = message;
|
||||
notification.className = "notification show"; // Reset classes
|
||||
if (type === "error") {
|
||||
notification.classList.add("error");
|
||||
}
|
||||
|
||||
// Clear previous timeout if exists
|
||||
if (notification.timeoutId) {
|
||||
clearTimeout(notification.timeoutId);
|
||||
}
|
||||
|
||||
notification.timeoutId = setTimeout(() => {
|
||||
notification.classList.remove("show");
|
||||
// Optional: remove the element after fade out if dynamically created
|
||||
// setTimeout(() => notification.remove(), 300);
|
||||
}, duration);
|
||||
}
|
||||
|
||||
// Helper to create notification element if it doesn't exist
|
||||
function createNotificationElement() {
|
||||
let notification = document.getElementById("notification");
|
||||
if (!notification) {
|
||||
notification = document.createElement("div");
|
||||
notification.id = "notification";
|
||||
notification.className = "notification";
|
||||
document.body.appendChild(notification);
|
||||
}
|
||||
return notification;
|
||||
}
|
||||
|
||||
// 页面刷新带加载状态
|
||||
function refreshPage(button) {
|
||||
if (button) {
|
||||
const icon = button.querySelector("i");
|
||||
if (icon) {
|
||||
icon.classList.add("loading-spin");
|
||||
}
|
||||
}
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 300); // Short delay to show spinner
|
||||
}
|
||||
|
||||
// --- Version Check ---
|
||||
const versionInfoContainer = document.getElementById(
|
||||
"version-info-container"
|
||||
);
|
||||
|
||||
async function fetchVersionInfo() {
|
||||
if (!versionInfoContainer) return;
|
||||
versionInfoContainer.innerHTML =
|
||||
'<span class="mx-1">|</span><span class="text-xs text-gray-700">检查更新中...</span>'; // Initial loading state
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/version/check");
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const data = await response.json();
|
||||
|
||||
let versionHtml = `<span class="mx-1">|</span><span class="text-xs text-gray-800">v${data.current_version}</span>`;
|
||||
if (data.update_available) {
|
||||
versionHtml += `
|
||||
<span class="mx-1">|</span>
|
||||
<a href="https://github.com/snailyp/gemini-balance/releases/latest" target="_blank" class="text-yellow-600 hover:text-yellow-800 transition duration-300 animate-pulse">
|
||||
<i class="fas fa-arrow-up"></i> 新版本: v${data.latest_version}
|
||||
</a>`;
|
||||
} else if (data.error_message) {
|
||||
versionHtml += `
|
||||
<span class="mx-1">|</span>
|
||||
<span class="text-xs text-red-500" title="${data.error_message}">更新检查失败</span>`;
|
||||
} else {
|
||||
versionHtml += `<span class="mx-1">|</span><span class="text-xs text-green-500">已是最新</span>`; // Indicate up-to-date
|
||||
}
|
||||
versionInfoContainer.innerHTML = versionHtml;
|
||||
} catch (error) {
|
||||
console.error("Error fetching version info:", error);
|
||||
versionInfoContainer.innerHTML = `<span class="mx-1">|</span><span class="text-xs text-red-500" title="无法连接到服务器或解析响应">更新检查失败</span>`;
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch immediately on load
|
||||
fetchVersionInfo();
|
||||
|
||||
// Fetch periodically (e.g., every hour)
|
||||
setInterval(fetchVersionInfo, 3600000); // 3600000 ms = 1 hour
|
||||
</script>
|
||||
{% block body_scripts %}{% endblock %}
|
||||
</body>
|
||||
</html>
|
||||
1772
app/templates/config_editor.html
Normal file
636
app/templates/error_logs.html
Normal file
@@ -0,0 +1,636 @@
|
||||
{% extends "base.html" %} {% block title %}错误日志管理 - Gemini Balance{%
|
||||
endblock %} {% block head_extra_styles %}
|
||||
<style>
|
||||
/* error_logs.html specific styles */
|
||||
.styled-table th {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background-color: rgba(80, 60, 160, 0.8); /* theming: table header bg */
|
||||
color: #ffffff !important; /* theming: table header text, ensured light */
|
||||
z-index: 10;
|
||||
border-bottom: 1px solid rgba(120, 100, 200, 0.4);
|
||||
}
|
||||
.styled-table tbody tr:hover {
|
||||
background-color: rgba(90, 70, 170, 0.4); /* theming: table row hover */
|
||||
}
|
||||
.styled-table td {
|
||||
padding: 12px 20px;
|
||||
vertical-align: middle;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
max-width: 250px;
|
||||
color: #d1d5db; /* theming: table cell text (gray-300) */
|
||||
border-bottom: 1px solid rgba(120, 100, 200, 0.2); /* theming: cell border */
|
||||
}
|
||||
.styled-table td:nth-child(4) {
|
||||
white-space: nowrap;
|
||||
}
|
||||
.btn-view-details {
|
||||
background-color: rgba(107, 70, 193, 0.4); /* theming */
|
||||
color: #c4b5fd; /* theming */
|
||||
padding: 6px 12px;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
transition: all 0.2s ease-in-out;
|
||||
border: 1px solid rgba(120, 100, 200, 0.6); /* theming */
|
||||
}
|
||||
.btn-view-details:hover {
|
||||
background-color: rgba(120, 100, 200, 0.6); /* theming */
|
||||
color: #ede9fe; /* theming */
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
@media (max-width: 768px) {
|
||||
.search-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
input[type="text"],
|
||||
input[type="datetime-local"],
|
||||
select,
|
||||
button {
|
||||
height: 36px !important;
|
||||
}
|
||||
.form-input-themed,
|
||||
input[type="datetime-local"],
|
||||
select#pageSize {
|
||||
background-color: rgba(255, 255, 255, 0.1) !important;
|
||||
border-color: rgba(120, 100, 200, 0.5) !important;
|
||||
color: #ffffff !important;
|
||||
}
|
||||
.form-input-themed::placeholder,
|
||||
input[type="datetime-local"]::placeholder {
|
||||
color: #a0aec0 !important;
|
||||
}
|
||||
.form-input-themed:focus,
|
||||
input[type="datetime-local"]:focus,
|
||||
select#pageSize:focus {
|
||||
border-color: #a78bfa !important;
|
||||
box-shadow: 0 0 0 3px rgba(167, 139, 250, 0.4) !important;
|
||||
}
|
||||
select#pageSize {
|
||||
/* Styles from config_editor.html .form-select-themed, adapted for select#pageSize */
|
||||
background-color: rgba(60, 40, 130, 0.6) !important;
|
||||
border: 1px solid rgba(167, 139, 250, 0.7) !important;
|
||||
color: #ffffff !important;
|
||||
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%23d8b4fe' stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;
|
||||
appearance: none !important;
|
||||
padding: 0.6rem 2.5rem 0.6rem 0.8rem !important;
|
||||
background-repeat: no-repeat !important;
|
||||
background-position: right 0.6rem center !important;
|
||||
background-size: 1.5em 1.5em !important;
|
||||
border-radius: 0.5rem !important;
|
||||
font-weight: 500 !important;
|
||||
height: 36px !important; /* Retain original height or use auto */
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1) !important;
|
||||
cursor: pointer !important;
|
||||
}
|
||||
|
||||
select#pageSize:focus {
|
||||
border-color: #d8b4fe !important; /* violet-300 */
|
||||
box-shadow: 0 0 0 3px rgba(216, 180, 254, 0.4) !important; /* ring-violet-300 */
|
||||
outline: none !important;
|
||||
}
|
||||
|
||||
select#pageSize option {
|
||||
background-color: rgba(76, 29, 149, 0.95) !important; /* 暗紫色背景 */
|
||||
color: #ffffff !important;
|
||||
padding: 8px !important;
|
||||
}
|
||||
|
||||
.date-range-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
input[type="datetime-local"] {
|
||||
min-width: 0;
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
label {
|
||||
color: #e2e8f0 !important; /* Light gray/white for labels */
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
/* 导航链接悬停样式 (从 config_editor.html 复制) */
|
||||
.nav-link {
|
||||
transition: all 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
.nav-link:hover {
|
||||
background-color: rgba(120, 100, 200, 0.6) !important;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
/* Ensure text around pageSize select is light */
|
||||
.pagination-text {
|
||||
color: #e2e8f0 !important; /* Light gray/white for text */
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
/* Pagination custom styles */
|
||||
.pagination li a, .pagination li span { /* Assuming 'span' might be used for non-clickable items like '...' */
|
||||
display: flex; /* For centering content if icons are used */
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.5rem 0.75rem; /* Adjust padding as needed */
|
||||
line-height: 1.25;
|
||||
color: #e2e8f0; /* Light gray/white text */
|
||||
background-color: rgba(107, 70, 193, 0.4); /* Consistent with other buttons */
|
||||
border: 1px solid rgba(120, 100, 200, 0.6); /* Consistent with other buttons */
|
||||
border-radius: 0.375rem; /* Tailwind's rounded-md */
|
||||
transition: all 0.2s ease-in-out;
|
||||
min-width: 36px; /* Ensure minimum width for small numbers */
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.pagination li a:hover, .pagination li span:hover:not(.disabled) { /* Avoid hover on disabled spans */
|
||||
color: #ffffff;
|
||||
background-color: rgba(120, 100, 200, 0.6); /* Consistent with other button hovers */
|
||||
border-color: rgba(167, 139, 250, 0.8);
|
||||
}
|
||||
|
||||
.pagination li.active a, .pagination li.active span { /* Assuming 'active' class for current page */
|
||||
color: #ffffff !important;
|
||||
background-color: #7c3aed !important; /* Violet-600, ensure it overrides */
|
||||
border-color: #7c3aed !important;
|
||||
font-weight: 600; /* Make active page number bolder */
|
||||
}
|
||||
|
||||
.pagination li.disabled a, .pagination li.disabled span { /* Assuming 'disabled' class */
|
||||
color: rgba(226, 232, 240, 0.6) !important;
|
||||
background-color: rgba(80, 60, 160, 0.3) !important; /* Slightly more visible than pure disabled */
|
||||
border-color: rgba(120, 100, 200, 0.4) !important;
|
||||
cursor: not-allowed;
|
||||
pointer-events: none;
|
||||
}
|
||||
</style>
|
||||
{% endblock %} {% block content %}
|
||||
<div class="container mx-auto px-4">
|
||||
<div
|
||||
class="rounded-2xl shadow-xl p-6 md:p-8"
|
||||
style="
|
||||
background-color: rgba(80, 60, 160, 0.3);
|
||||
backdrop-filter: blur(10px);
|
||||
-webkit-backdrop-filter: blur(10px);
|
||||
border: 1px solid rgba(150, 130, 230, 0.3);
|
||||
"
|
||||
>
|
||||
<h1
|
||||
class="text-3xl font-extrabold text-center text-transparent bg-clip-text bg-gradient-to-r from-violet-400 to-pink-400 mb-4"
|
||||
>
|
||||
<img
|
||||
src="/static/icons/logo.png"
|
||||
alt="Gemini Balance Logo"
|
||||
class="h-9 inline-block align-middle mr-2"
|
||||
/>
|
||||
Gemini Balance - 错误日志
|
||||
</h1>
|
||||
|
||||
<!-- Navigation Tabs -->
|
||||
<div class="flex justify-center mb-8 overflow-x-auto pb-2 gap-2">
|
||||
<a
|
||||
href="/config"
|
||||
class="nav-link whitespace-nowrap flex items-center justify-center gap-2 px-6 py-3 font-medium rounded-lg text-gray-200 hover:text-white transition-all duration-200"
|
||||
style="background-color: rgba(107, 70, 193, 0.4)"
|
||||
>
|
||||
<i class="fas fa-cog"></i> 配置编辑
|
||||
</a>
|
||||
<a
|
||||
href="/keys"
|
||||
class="nav-link whitespace-nowrap flex items-center justify-center gap-2 px-6 py-3 font-medium rounded-lg text-gray-200 hover:text-white transition-all duration-200"
|
||||
style="background-color: rgba(107, 70, 193, 0.4)"
|
||||
>
|
||||
<i class="fas fa-tachometer-alt"></i> 监控面板
|
||||
</a>
|
||||
<a
|
||||
href="/logs"
|
||||
class="whitespace-nowrap flex items-center justify-center gap-2 px-6 py-3 font-medium rounded-lg bg-violet-600 text-white shadow-md"
|
||||
>
|
||||
<i class="fas fa-exclamation-triangle"></i> 错误日志
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<!-- 主内容区域 -->
|
||||
<div
|
||||
class="rounded-xl p-6 shadow-lg animate-fade-in"
|
||||
style="
|
||||
background-color: rgba(70, 50, 150, 0.5);
|
||||
backdrop-filter: blur(5px);
|
||||
-webkit-backdrop-filter: blur(5px);
|
||||
border: 1px solid rgba(120, 100, 200, 0.2);
|
||||
"
|
||||
>
|
||||
<h2
|
||||
class="text-xl font-bold mb-6 pb-3 border-b flex items-center gap-2 text-gray-100 border-violet-300 border-opacity-30"
|
||||
>
|
||||
<i class="fas fa-bug text-violet-400"></i> 错误日志列表
|
||||
</h2>
|
||||
|
||||
<!-- 搜索与操作控件 -->
|
||||
<div
|
||||
class="grid grid-cols-1 lg:grid-cols-[1fr_auto] items-center gap-4 mb-6"
|
||||
>
|
||||
<div
|
||||
class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 w-full"
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
id="keySearch"
|
||||
placeholder="搜索密钥 (部分)"
|
||||
class="px-3 py-1 rounded-lg border form-input-themed"
|
||||
/>
|
||||
<input
|
||||
type="text"
|
||||
id="errorSearch"
|
||||
placeholder="搜索错误类型/日志"
|
||||
class="px-3 py-1 rounded-lg border form-input-themed"
|
||||
/>
|
||||
<input
|
||||
type="text"
|
||||
id="errorCodeSearch"
|
||||
placeholder="搜索错误码"
|
||||
class="px-3 py-1 rounded-lg border form-input-themed"
|
||||
/>
|
||||
<div
|
||||
class="grid grid-cols-1 sm:grid-cols-2 gap-2 col-span-1 sm:col-span-2 lg:col-span-3 mt-2"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<label class="text-sm text-gray-300 whitespace-nowrap"
|
||||
>开始时间:</label
|
||||
>
|
||||
<input
|
||||
type="datetime-local"
|
||||
id="startDate"
|
||||
class="px-3 py-1 rounded-lg border text-sm w-full"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<label class="text-sm text-gray-300 whitespace-nowrap"
|
||||
>结束时间:</label
|
||||
>
|
||||
<input
|
||||
type="datetime-local"
|
||||
id="endDate"
|
||||
class="px-3 py-1 rounded-lg border text-sm w-full"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-3 flex-shrink-0">
|
||||
<button
|
||||
id="searchBtn"
|
||||
class="flex items-center justify-center px-4 py-1.5 bg-violet-600 hover:bg-violet-700 text-white rounded-lg font-medium transition-all duration-200 shadow-sm hover:shadow-md whitespace-nowrap"
|
||||
>
|
||||
<i class="fas fa-search mr-1.5"></i>搜索
|
||||
</button>
|
||||
<button
|
||||
id="copySelectedKeysBtn"
|
||||
class="flex items-center justify-center px-4 py-1.5 bg-sky-600 hover:bg-sky-700 text-white rounded-lg font-medium transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed shadow-sm hover:shadow-md whitespace-nowrap"
|
||||
disabled
|
||||
>
|
||||
<i class="far fa-copy mr-1.5"></i>复制
|
||||
</button>
|
||||
<button
|
||||
id="deleteSelectedBtn"
|
||||
class="flex items-center justify-center px-4 py-1.5 bg-red-600 hover:bg-red-700 text-white rounded-lg font-medium transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed shadow-sm hover:shadow-md whitespace-nowrap"
|
||||
disabled
|
||||
>
|
||||
<i class="fas fa-trash-alt mr-1.5"></i>删除
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 表格容器 -->
|
||||
<div
|
||||
class="overflow-x-auto rounded-lg border mb-6"
|
||||
style="border-color: rgba(120, 100, 200, 0.3)"
|
||||
>
|
||||
<table class="styled-table w-full min-w-full text-sm">
|
||||
<thead>
|
||||
<tr class="text-left">
|
||||
<th
|
||||
class="px-3 py-3 font-semibold rounded-tl-lg w-12 text-center"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
id="selectAllCheckbox"
|
||||
class="form-checkbox h-4 w-4 text-violet-500 border-gray-500 rounded focus:ring-violet-500 bg-transparent"
|
||||
/>
|
||||
</th>
|
||||
<th class="px-5 py-3 font-semibold cursor-pointer" id="sortById">
|
||||
ID <i class="fas fa-sort ml-1"></i>
|
||||
</th>
|
||||
<th class="px-5 py-3 font-semibold">Gemini密钥</th>
|
||||
<th class="px-5 py-3 font-semibold">错误类型</th>
|
||||
<th class="px-5 py-3 font-semibold">错误码</th>
|
||||
<th class="px-5 py-3 font-semibold">模型名称</th>
|
||||
<th class="px-5 py-3 font-semibold">请求时间</th>
|
||||
<th class="px-5 py-3 font-semibold rounded-tr-lg text-center">
|
||||
操作
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody
|
||||
id="errorLogsTable"
|
||||
class="divide-y"
|
||||
style="border-color: rgba(120, 100, 200, 0.2)"
|
||||
>
|
||||
<!-- 错误日志数据将通过JavaScript动态加载 -->
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- 状态指示器 -->
|
||||
<div
|
||||
id="loadingIndicator"
|
||||
class="flex items-center justify-center p-8 hidden"
|
||||
>
|
||||
<div
|
||||
class="animate-spin rounded-full h-12 w-12 border-b-2 border-violet-400"
|
||||
></div>
|
||||
<p class="ml-4 text-lg text-gray-300 font-medium">加载中,请稍候...</p>
|
||||
</div>
|
||||
|
||||
<div id="noDataMessage" class="text-center py-12 text-gray-400 hidden">
|
||||
<i class="fas fa-inbox text-5xl mb-3"></i>
|
||||
<p class="text-lg">暂无错误日志数据</p>
|
||||
</div>
|
||||
|
||||
<div
|
||||
id="errorMessage"
|
||||
class="p-4 rounded-lg font-medium text-center hidden"
|
||||
style="background-color: rgba(220, 38, 38, 0.2); color: #fca5a5"
|
||||
>
|
||||
<i class="fas fa-exclamation-circle mr-2"></i>
|
||||
加载错误日志失败,请稍后重试。
|
||||
</div>
|
||||
|
||||
<!-- 分页与每页显示控件 -->
|
||||
<div
|
||||
class="flex flex-col sm:flex-row justify-between items-center mt-6 gap-4"
|
||||
>
|
||||
<div class="flex items-center gap-2 text-sm text-gray-300">
|
||||
<label for="pageSize" class="font-medium pagination-text"
|
||||
>每页显示:</label
|
||||
>
|
||||
<select
|
||||
id="pageSize"
|
||||
class="rounded-md border focus:ring focus:border-violet-400 px-2 py-1 text-sm"
|
||||
>
|
||||
<option value="10">10</option>
|
||||
<option value="20" selected>20</option>
|
||||
<option value="50">50</option>
|
||||
<option value="100">100</option>
|
||||
</select>
|
||||
<span class="pagination-text">条</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-4">
|
||||
<ul class="pagination flex items-center gap-1" id="pagination">
|
||||
<!-- 分页控件将通过JavaScript动态加载 -->
|
||||
</ul>
|
||||
<div class="flex items-center gap-1">
|
||||
<input
|
||||
type="number"
|
||||
id="pageInput"
|
||||
min="1"
|
||||
class="w-16 px-2 py-1 rounded-md border text-sm focus:ring focus:border-violet-400 form-input-themed"
|
||||
placeholder="页码"
|
||||
/>
|
||||
<button
|
||||
id="goToPageBtn"
|
||||
class="px-3 py-1 bg-violet-600 hover:bg-violet-700 text-white text-sm rounded-md transition"
|
||||
>
|
||||
跳转
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Scroll buttons are now in base.html -->
|
||||
<div class="scroll-buttons">
|
||||
<button class="scroll-button" onclick="scrollToTop()" title="回到顶部">
|
||||
<i class="fas fa-chevron-up"></i>
|
||||
</button>
|
||||
<button class="scroll-button" onclick="scrollToBottom()" title="滚动到底部">
|
||||
<i class="fas fa-chevron-down"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Notification component is now in base.html (use id="notification") -->
|
||||
<div id="notification" class="notification"></div>
|
||||
<!-- Footer is now in base.html -->
|
||||
|
||||
<!-- 日志详情模态框 -->
|
||||
<div id="logDetailModal" class="modal">
|
||||
<div
|
||||
class="w-full max-w-6xl mx-auto rounded-2xl shadow-2xl overflow-hidden animate-fade-in"
|
||||
style="
|
||||
background-color: rgba(70, 50, 150, 0.95);
|
||||
color: #ffffff;
|
||||
border: 1px solid rgba(120, 100, 200, 0.4);
|
||||
"
|
||||
>
|
||||
<div class="p-6">
|
||||
<div
|
||||
class="flex justify-between items-center pb-4 mb-4"
|
||||
style="border-bottom: 1px solid rgba(120, 100, 200, 0.4)"
|
||||
>
|
||||
<h2 class="text-xl font-bold text-gray-100">错误日志详情</h2>
|
||||
<button
|
||||
id="closeLogDetailModalBtn"
|
||||
class="text-gray-300 hover:text-gray-100 text-xl"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="space-y-4 max-h-[60vh] overflow-y-auto p-1">
|
||||
<div
|
||||
class="p-4 rounded-lg relative group"
|
||||
style="background-color: rgba(80, 60, 160, 0.3)"
|
||||
>
|
||||
<h6 class="text-sm font-semibold text-violet-200 mb-1">
|
||||
Gemini密钥:
|
||||
</h6>
|
||||
<pre
|
||||
id="modalGeminiKey"
|
||||
class="font-mono text-sm p-3 rounded overflow-x-auto"
|
||||
style="background-color: rgba(0, 0, 0, 0.2); color: #e5e7eb"
|
||||
></pre>
|
||||
<button
|
||||
class="copy-btn absolute top-2 right-2 hover:bg-gray-600 text-gray-300 p-1.5 rounded opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
style="background-color: rgba(0, 0, 0, 0.3)"
|
||||
data-target="modalGeminiKey"
|
||||
title="复制密钥"
|
||||
>
|
||||
<i class="far fa-copy"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="p-4 rounded-lg relative group"
|
||||
style="background-color: rgba(80, 60, 160, 0.3)"
|
||||
>
|
||||
<h6 class="text-sm font-semibold text-violet-200 mb-1">错误类型:</h6>
|
||||
<p id="modalErrorType" class="text-red-300 font-medium pr-8"></p>
|
||||
<button
|
||||
class="copy-btn absolute top-2 right-2 hover:bg-gray-600 text-gray-300 p-1.5 rounded opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
style="background-color: rgba(0, 0, 0, 0.3)"
|
||||
data-target="modalErrorType"
|
||||
title="复制错误类型"
|
||||
>
|
||||
<i class="far fa-copy"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="p-4 rounded-lg relative group"
|
||||
style="background-color: rgba(80, 60, 160, 0.3)"
|
||||
>
|
||||
<h6 class="text-sm font-semibold text-violet-200 mb-1">错误日志:</h6>
|
||||
<pre
|
||||
id="modalErrorLog"
|
||||
class="font-mono text-sm p-3 rounded overflow-x-auto whitespace-pre-wrap"
|
||||
style="background-color: rgba(0, 0, 0, 0.2); color: #e5e7eb"
|
||||
></pre>
|
||||
<button
|
||||
class="copy-btn absolute top-2 right-2 hover:bg-gray-600 text-gray-300 p-1.5 rounded opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
style="background-color: rgba(0, 0, 0, 0.3)"
|
||||
data-target="modalErrorLog"
|
||||
title="复制错误日志"
|
||||
>
|
||||
<i class="far fa-copy"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="p-4 rounded-lg relative group"
|
||||
style="background-color: rgba(80, 60, 160, 0.3)"
|
||||
>
|
||||
<h6 class="text-sm font-semibold text-violet-200 mb-1">请求消息:</h6>
|
||||
<pre
|
||||
id="modalRequestMsg"
|
||||
class="font-mono text-sm p-3 rounded overflow-x-auto whitespace-pre-wrap"
|
||||
style="background-color: rgba(0, 0, 0, 0.2); color: #e5e7eb"
|
||||
></pre>
|
||||
<button
|
||||
class="copy-btn absolute top-2 right-2 hover:bg-gray-600 text-gray-300 p-1.5 rounded opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
style="background-color: rgba(0, 0, 0, 0.3)"
|
||||
data-target="modalRequestMsg"
|
||||
title="复制请求消息"
|
||||
>
|
||||
<i class="far fa-copy"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="p-4 rounded-lg relative group"
|
||||
style="background-color: rgba(80, 60, 160, 0.3)"
|
||||
>
|
||||
<h6 class="text-sm font-semibold text-violet-200 mb-1">模型名称:</h6>
|
||||
<p id="modalModelName" class="font-medium pr-8 text-gray-200"></p>
|
||||
<button
|
||||
class="copy-btn absolute top-2 right-2 hover:bg-gray-600 text-gray-300 p-1.5 rounded opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
style="background-color: rgba(0, 0, 0, 0.3)"
|
||||
data-target="modalModelName"
|
||||
title="复制模型名称"
|
||||
>
|
||||
<i class="far fa-copy"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="p-4 rounded-lg relative group"
|
||||
style="background-color: rgba(80, 60, 160, 0.3)"
|
||||
>
|
||||
<h6 class="text-sm font-semibold text-violet-200 mb-1">请求时间:</h6>
|
||||
<p id="modalRequestTime" class="font-medium pr-8 text-gray-200"></p>
|
||||
<button
|
||||
class="copy-btn absolute top-2 right-2 hover:bg-gray-600 text-gray-300 p-1.5 rounded opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
style="background-color: rgba(0, 0, 0, 0.3)"
|
||||
data-target="modalRequestTime"
|
||||
title="复制请求时间"
|
||||
>
|
||||
<i class="far fa-copy"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="flex justify-end mt-6 pt-4"
|
||||
style="border-top: 1px solid rgba(120, 100, 200, 0.4)"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
id="closeModalFooterBtn"
|
||||
class="bg-gray-500 bg-opacity-50 hover:bg-opacity-70 text-gray-200 px-6 py-2 rounded-lg font-medium transition"
|
||||
>
|
||||
关闭
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 删除确认模态框 -->
|
||||
<div id="deleteConfirmModal" class="modal">
|
||||
<div
|
||||
class="w-full max-w-md mx-auto rounded-xl shadow-xl overflow-hidden animate-fade-in"
|
||||
style="
|
||||
background-color: rgba(70, 50, 150, 0.95);
|
||||
color: #ffffff;
|
||||
border: 1px solid rgba(120, 100, 200, 0.4);
|
||||
"
|
||||
>
|
||||
<div class="p-6">
|
||||
<div
|
||||
class="flex justify-between items-center pb-3 mb-4"
|
||||
style="border-bottom: 1px solid rgba(120, 100, 200, 0.4)"
|
||||
>
|
||||
<h2 class="text-lg font-semibold text-gray-100">确认删除</h2>
|
||||
<button
|
||||
id="closeDeleteConfirmModalBtn"
|
||||
class="text-gray-300 hover:text-gray-100 text-xl"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
<p id="deleteConfirmMessage" class="text-gray-300 mb-6">
|
||||
你确定要删除选中的项目吗?此操作不可恢复!
|
||||
</p>
|
||||
<div class="flex justify-end gap-3">
|
||||
<button
|
||||
id="cancelDeleteBtn"
|
||||
type="button"
|
||||
class="bg-gray-500 bg-opacity-50 hover:bg-opacity-70 text-gray-200 px-5 py-2 rounded-lg font-medium transition"
|
||||
>
|
||||
取消
|
||||
</button>
|
||||
<button
|
||||
id="confirmDeleteBtn"
|
||||
type="button"
|
||||
class="bg-red-600 hover:bg-red-700 text-white px-5 py-2 rounded-lg font-medium transition"
|
||||
>
|
||||
确认删除
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% endblock %} {% block body_scripts %}
|
||||
<script src="/static/js/error_logs.js"></script>
|
||||
<script>
|
||||
// error_logs.html specific JS initialization (if any)
|
||||
// e.g., initialize date pickers or other elements if needed
|
||||
// The main logic is in error_logs.js
|
||||
</script>
|
||||
{% endblock %}
|
||||
3
app/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
工具包初始化模块
|
||||
"""
|
||||
176
app/utils/helpers.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
通用工具函数模块
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
import requests
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import logging # Import logging
|
||||
|
||||
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
|
||||
|
||||
# Define logger for helper functions if needed, or use specific loggers
|
||||
helper_logger = logging.getLogger("app.utils") # Or use a more specific logger if available
|
||||
|
||||
# Define project root and version file path here for get_current_version
|
||||
# Assuming this file is at app/utils/helpers.py
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
VERSION_FILE_PATH = PROJECT_ROOT / "VERSION"
|
||||
|
||||
|
||||
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
从 base64 字符串中提取 MIME 类型和数据
|
||||
|
||||
Args:
|
||||
base64_string: 可能包含 MIME 类型信息的 base64 字符串
|
||||
|
||||
Returns:
|
||||
tuple: (mime_type, encoded_data)
|
||||
"""
|
||||
# 检查字符串是否以 "data:" 格式开始
|
||||
if base64_string.startswith('data:'):
|
||||
# 提取 MIME 类型和数据
|
||||
pattern = DATA_URL_PATTERN
|
||||
match = re.match(pattern, base64_string)
|
||||
if match:
|
||||
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
encoded_data = match.group(2)
|
||||
return mime_type, encoded_data
|
||||
|
||||
# 如果不是预期格式,假定它只是数据部分
|
||||
return None, base64_string
|
||||
|
||||
|
||||
def convert_image_to_base64(url: str) -> str:
|
||||
"""
|
||||
将图片URL转换为base64编码
|
||||
|
||||
Args:
|
||||
url: 图片URL
|
||||
|
||||
Returns:
|
||||
str: base64编码的图片数据
|
||||
|
||||
Raises:
|
||||
Exception: 如果获取图片失败
|
||||
"""
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# 将图片内容转换为base64
|
||||
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||
return img_data
|
||||
else:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
|
||||
|
||||
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
|
||||
"""
|
||||
格式化JSON响应
|
||||
|
||||
Args:
|
||||
data: 要格式化的数据
|
||||
indent: 缩进空格数
|
||||
|
||||
Returns:
|
||||
str: 格式化后的JSON字符串
|
||||
"""
|
||||
return json.dumps(data, indent=indent, ensure_ascii=False)
|
||||
|
||||
|
||||
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
|
||||
"""
|
||||
从prompt中解析参数
|
||||
|
||||
支持的格式:
|
||||
- {n:数量} 例如: {n:2} 生成2张图片
|
||||
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||
|
||||
Args:
|
||||
prompt: 提示文本
|
||||
default_ratio: 默认比例
|
||||
|
||||
Returns:
|
||||
tuple: (清理后的提示文本, 图片数量, 比例)
|
||||
"""
|
||||
# 默认值
|
||||
n = 1
|
||||
aspect_ratio = default_ratio
|
||||
|
||||
# 解析n参数
|
||||
n_match = re.search(r'{n:(\d+)}', prompt)
|
||||
if n_match:
|
||||
n = int(n_match.group(1))
|
||||
if n < 1 or n > 4:
|
||||
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||
prompt = prompt.replace(n_match.group(0), '').strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||
if ratio_match:
|
||||
aspect_ratio = ratio_match.group(1)
|
||||
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||
raise ValueError(
|
||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||
)
|
||||
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||
|
||||
return prompt, n, aspect_ratio
|
||||
|
||||
|
||||
def extract_image_urls_from_markdown(text: str) -> List[str]:
|
||||
"""
|
||||
从Markdown文本中提取图片URL
|
||||
|
||||
Args:
|
||||
text: Markdown文本
|
||||
|
||||
Returns:
|
||||
List[str]: 图片URL列表
|
||||
"""
|
||||
pattern = IMAGE_URL_PATTERN
|
||||
matches = re.findall(pattern, text)
|
||||
return [match[1] for match in matches]
|
||||
|
||||
|
||||
def is_valid_api_key(key: str) -> bool:
|
||||
"""
|
||||
检查API密钥格式是否有效
|
||||
|
||||
Args:
|
||||
key: API密钥
|
||||
|
||||
Returns:
|
||||
bool: 如果密钥格式有效则返回True
|
||||
"""
|
||||
# 检查Gemini API密钥格式
|
||||
if key.startswith('AIza'):
|
||||
return len(key) >= 30
|
||||
|
||||
# 检查OpenAI API密钥格式
|
||||
if key.startswith('sk-'):
|
||||
return len(key) >= 30
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def get_current_version(default_version: str = "0.0.0") -> str:
|
||||
"""Reads the current version from the VERSION file."""
|
||||
version_file = VERSION_FILE_PATH # Use Path object defined above
|
||||
try:
|
||||
# Use Path object's open method
|
||||
with version_file.open('r', encoding='utf-8') as f:
|
||||
version = f.read().strip()
|
||||
if not version:
|
||||
helper_logger.warning(f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
return version
|
||||
except FileNotFoundError:
|
||||
helper_logger.warning(f"VERSION file not found at '{version_file}'. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
except IOError as e:
|
||||
helper_logger.error(f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'.")
|
||||
return default_version
|
||||
@@ -1,5 +1,5 @@
|
||||
import requests
|
||||
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
|
||||
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
|
||||
@@ -290,9 +290,9 @@ class CloudFlareImgBedUploader(ImageUploader):
|
||||
try:
|
||||
# 准备请求URL(添加认证码参数,如果存在)
|
||||
if self.auth_code:
|
||||
request_url = f"{self.api_url}?authCode={self.auth_code}"
|
||||
request_url = f"{self.api_url}?authCode={self.auth_code}&uploadNameType=origin"
|
||||
else:
|
||||
request_url = self.api_url
|
||||
request_url = f"{self.api_url}?uploadNameType=origin"
|
||||
|
||||
# 准备文件数据
|
||||
files = {
|
||||
@@ -1,9 +1,39 @@
|
||||
version: '3'
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
services:
|
||||
gemini-balance:
|
||||
build: .
|
||||
image: ghcr.io/snailyp/gemini-balance:latest
|
||||
container_name: gemini-balance
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8000:8000"
|
||||
env_file:
|
||||
- .env
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "python -c \"import requests; exit(0) if requests.get('http://localhost:8000/health').status_code == 200 else exit(1)\""]
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
mysql:
|
||||
image: mysql:8
|
||||
container_name: gemini-balance-mysql
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: your_root_password
|
||||
MYSQL_DATABASE: ${MYSQL_DATABASE}
|
||||
MYSQL_USER: ${MYSQL_USER}
|
||||
MYSQL_PASSWORD: ${MYSQL_PASSWORD}
|
||||
# ports:
|
||||
# - "3306:3306"
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "127.0.0.1"]
|
||||
interval: 10s # 每隔10秒检查一次
|
||||
timeout: 5s # 每次检查的超时时间为5秒
|
||||
retries: 3 # 重试3次失败后标记为 unhealthy
|
||||
start_period: 30s # 容器启动后等待30秒再开始第一次健康检查
|
||||
BIN
files/image.png
Normal file
|
After Width: | Height: | Size: 347 KiB |
BIN
files/image1.png
Normal file
|
After Width: | Height: | Size: 281 KiB |
BIN
files/image2.png
Normal file
|
After Width: | Height: | Size: 328 KiB |
BIN
files/image3.png
Normal file
|
After Width: | Height: | Size: 230 KiB |
BIN
files/image4.png
Normal file
|
After Width: | Height: | Size: 459 KiB |
BIN
files/image5.png
Normal file
|
After Width: | Height: | Size: 292 KiB |
BIN
files/image6.png
Normal file
|
After Width: | Height: | Size: 163 KiB |
BIN
files/image7.png
Normal file
|
After Width: | Height: | Size: 665 KiB |
BIN
files/image8.png
Normal file
|
After Width: | Height: | Size: 97 KiB |
@@ -1,5 +1,5 @@
|
||||
fastapi
|
||||
httpx
|
||||
httpx[socks]
|
||||
openai
|
||||
pydantic
|
||||
pydantic_settings
|
||||
@@ -9,3 +9,13 @@ uvicorn
|
||||
google-genai
|
||||
jinja2
|
||||
python-multipart
|
||||
cryptography # 支持 MySQL 8+ caching_sha2_password 验证
|
||||
# 数据库相关依赖
|
||||
pymysql
|
||||
sqlalchemy
|
||||
aiomysql
|
||||
aiosqlite
|
||||
databases
|
||||
python-dotenv
|
||||
apscheduler
|
||||
packaging
|
||||
|
||||