mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-07 06:23:02 +08:00
first commit
This commit is contained in:
259
.gitignore
vendored
Normal file
259
.gitignore
vendored
Normal file
@@ -0,0 +1,259 @@
|
||||
# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,circuitpython,python,pythonvanilla
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,visualstudiocode,circuitpython,python,pythonvanilla
|
||||
|
||||
### CircuitPython ###
|
||||
.Trashes
|
||||
.metadata_never_index
|
||||
.fseventsd/
|
||||
boot_out.txt
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
### Python Patch ###
|
||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||
poetry.toml
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# LSP config files
|
||||
pyrightconfig.json
|
||||
|
||||
### PythonVanilla ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
||||
# C extensions
|
||||
|
||||
# Distribution / packaging
|
||||
|
||||
# Installer logs
|
||||
|
||||
# Unit test / coverage reports
|
||||
|
||||
# Translations
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
|
||||
|
||||
### VisualStudioCode ###
|
||||
.vscode/*
|
||||
!.vscode/settings.json
|
||||
!.vscode/tasks.json
|
||||
!.vscode/launch.json
|
||||
!.vscode/extensions.json
|
||||
!.vscode/*.code-snippets
|
||||
|
||||
# Local History for Visual Studio Code
|
||||
.history/
|
||||
|
||||
# Built Visual Studio Code Extensions
|
||||
*.vsix
|
||||
|
||||
### VisualStudioCode Patch ###
|
||||
# Ignore all local history of files
|
||||
.history
|
||||
.ionide
|
||||
|
||||
### Windows ###
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,circuitpython,python,pythonvanilla
|
||||
|
||||
# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
|
||||
|
||||
tests/
|
||||
19
.vscode/launch.json
vendored
Normal file
19
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
// 使用 IntelliSense 了解相关属性。
|
||||
// 悬停以查看现有属性的描述。
|
||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python 调试程序: FastAPI",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"main:app",
|
||||
"--reload"
|
||||
],
|
||||
"jinja": true
|
||||
}
|
||||
]
|
||||
}
|
||||
19
Dockerfile
Normal file
19
Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
FROM python:3.9-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制所需文件到容器中
|
||||
COPY ./app /app/app
|
||||
COPY ./main.py /app
|
||||
COPY ./requirements.txt /app
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
ENV API_KEYS=["your_api_key_1"]
|
||||
ENV ALLOWED_TOKENS=["your_token_1"]
|
||||
ENV BASE_URL=https://api.groq.com/openai/v1
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Run the application
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
108
README.md
Normal file
108
README.md
Normal file
@@ -0,0 +1,108 @@
|
||||
# 🚀 FastAPI OpenAI 代理服务
|
||||
|
||||
## 📝 项目简介
|
||||
|
||||
这是一个基于 FastAPI 框架开发的 OpenAI API 代理服务,支持多 API Key 轮询和流式响应。
|
||||
|
||||
## ✨ 主要特性
|
||||
|
||||
- 🔄 多 API Key 轮询支持
|
||||
- 🔐 Bearer Token 认证
|
||||
- 📡 支持流式响应
|
||||
- 🌐 CORS 跨域支持
|
||||
- 📊 健康检查接口
|
||||
|
||||
## 🛠️ 技术栈
|
||||
|
||||
- FastAPI
|
||||
- OpenAI
|
||||
- Pydantic
|
||||
- Docker
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Python 3.9+
|
||||
- Docker (可选)
|
||||
|
||||
### 📦 安装依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### ⚙️ 配置文件
|
||||
|
||||
创建 `.env` 文件并配置以下参数:
|
||||
|
||||
```env
|
||||
API_KEYS=["your-api-key-1","your-api-key-2"]
|
||||
ALLOWED_TOKENS=["your-access-token-1","your-access-token-2"]
|
||||
BASE_URL="https://api.openai.com/v1"
|
||||
```
|
||||
|
||||
### 🐳 Docker 部署
|
||||
|
||||
```bash
|
||||
docker build -t openai-compatible-balance .

docker run -p 8000:8000 -d openai-compatible-balance
|
||||
```
|
||||
|
||||
## 🔌 API 接口
|
||||
|
||||
### 获取模型列表
|
||||
|
||||
```http
|
||||
GET /v1/models
|
||||
Authorization: Bearer your-token
|
||||
```
|
||||
|
||||
### 聊天完成
|
||||
|
||||
```http
|
||||
POST /v1/chat/completions
|
||||
Authorization: Bearer your-token
|
||||
|
||||
{
|
||||
"messages": [...],
|
||||
"model": "llama-3.2-90b-text-preview",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
### 健康检查
|
||||
|
||||
```http
|
||||
GET /health
|
||||
Authorization: Bearer your-token
|
||||
```
|
||||
|
||||
## 📚 代码结构
|
||||
|
||||
- `main.py`: 主应用程序入口
|
||||
- `app/config.py`: 配置管理
|
||||
- `Dockerfile`: 容器化配置
|
||||
- `requirements.txt`: 项目依赖
|
||||
|
||||
## 🔒 安全特性
|
||||
|
||||
- API Key 轮询机制
|
||||
- Bearer Token 认证
|
||||
- 请求日志记录
|
||||
|
||||
## 📝 注意事项
|
||||
|
||||
- 请确保妥善保管 API Keys 和访问令牌
|
||||
- 建议在生产环境中使用环境变量配置敏感信息
|
||||
- 默认服务端口为 8000
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
欢迎提交 Issue 和 Pull Request!
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
MIT License
|
||||
20
app/config.py
Normal file
20
app/config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables and an optional .env file."""

    # Upstream API keys that the proxy rotates through.
    API_KEYS: List[str]
    # Bearer tokens allowed to call this proxy.
    ALLOWED_TOKENS: List[str]
    # Base URL of the upstream OpenAI-compatible API.
    BASE_URL: str
    # Models routed directly to the Gemini API (with Google Search tooling) in main.py.
    MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = True
        # Read configuration from both environment variables and the .env file.
        env_nested_delimiter = "__"
        extra = "ignore"
|
||||
|
||||
# Prefer environment variables; fall back to the .env file (path overridable via ENV_FILE).
settings = Settings(_env_file=os.getenv("ENV_FILE", ".env"))
|
||||
334
main.py
Normal file
334
main.py
Normal file
@@ -0,0 +1,334 @@
|
||||
from fastapi import FastAPI, HTTPException, Header, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import openai
|
||||
from typing import List, Optional, Union
|
||||
import logging
|
||||
from itertools import cycle
|
||||
import asyncio
|
||||
|
||||
import uvicorn
|
||||
|
||||
from app import config
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import httpx
|
||||
import uuid
|
||||
import time
|
||||
|
||||
# Logging configuration
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()
|
||||
|
||||
# Allow cross-origin requests from any origin (browser clients call this proxy directly).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# API key configuration
API_KEYS = config.settings.API_KEYS

# Round-robin iterator over the configured keys.
key_cycle = cycle(API_KEYS)
# Guards key_cycle and key_failure_counts across concurrent requests.
key_lock = asyncio.Lock()

# Per-key failure counters; a key is skipped once it reaches MAX_FAILURES.
key_failure_counts = {key: 0 for key in API_KEYS}
MAX_FAILURES = 1  # maximum failure-count threshold
|
||||
|
||||
|
||||
async def get_next_working_key():
    """Return the next API key that has not exceeded the failure threshold.

    If every key has reached MAX_FAILURES, all failure counters are reset
    and the current key is returned so requests can keep flowing.
    """
    async with key_lock:
        candidate = next(key_cycle)
        first_seen = candidate

        # Skip keys that failed too often; stop once we wrap all the way around.
        while key_failure_counts[candidate] >= MAX_FAILURES:
            candidate = next(key_cycle)
            if candidate == first_seen:
                # Every key is exhausted — give them all a fresh start.
                for k in key_failure_counts:
                    key_failure_counts[k] = 0
                break

        return candidate
|
||||
|
||||
|
||||
async def handle_api_failure(api_key):
    """Record a failure for *api_key* and return the key to use next.

    Increments the key's failure counter under the lock. Once the counter
    reaches MAX_FAILURES, a replacement key is fetched. The replacement
    lookup happens OUTSIDE the lock because get_next_working_key()
    acquires key_lock itself and asyncio.Lock is NOT reentrant — awaiting
    it while already holding the lock deadlocks the request (and with
    MAX_FAILURES == 1 that happened on the very first failure).
    """
    async with key_lock:
        key_failure_counts[api_key] += 1
        exhausted = key_failure_counts[api_key] >= MAX_FAILURES
    if exhausted:
        logger.warning(
            f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key"
        )
        return await get_next_working_key()
    return api_key
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body for /v1/chat/completions (OpenAI-compatible)."""

    messages: List[dict]  # chat history, OpenAI format: {"role": ..., "content": ...}
    model: str = "gemini-1.5-flash-002"
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False  # True -> server-sent-events streaming response
    tools: Optional[List[dict]] = []  # pydantic deep-copies mutable defaults, so this is safe
    tool_choice: Optional[str] = "auto"
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
    """Request body for /v1/embeddings (OpenAI-compatible)."""

    input: Union[str, List[str]]  # a single text or a batch of texts to embed
    model: str = "text-embedding-004"
    encoding_format: Optional[str] = "float"
|
||||
|
||||
|
||||
async def verify_authorization(authorization: str = Header(None)):
    """Validate the Authorization header and return the bearer token.

    Raises HTTPException(401) when the header is missing, is not a Bearer
    scheme, or the token is not in the configured allow-list.
    """
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith("Bearer "):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    # Strip only the leading scheme. str.replace() removed EVERY "Bearer "
    # occurrence, corrupting any token that contains that substring.
    token = authorization.removeprefix("Bearer ")
    if token not in config.settings.ALLOWED_TOKENS:
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
|
||||
|
||||
|
||||
def get_gemini_models(api_key):
    """Fetch the Gemini model list and convert it to OpenAI /v1/models format.

    Returns None on any failure (network error or non-200 status),
    mirroring the original best-effort behavior.
    """
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"

    try:
        # A timeout is mandatory: without one a stalled upstream would hang
        # the worker indefinitely.
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            gemini_models = response.json()
            return convert_to_openai_models_format(gemini_models)
        # Log through the module logger instead of print so failures show up
        # in the same stream as the rest of the service's logs.
        logger.error(
            "Failed to list Gemini models: HTTP %s - %s",
            response.status_code,
            response.text,
        )
        return None

    except requests.RequestException as e:
        logger.error("Request to Gemini models endpoint failed: %s", e)
        return None
|
||||
|
||||
|
||||
def convert_to_openai_models_format(gemini_models):
    """Translate a Gemini ListModels response into the OpenAI /v1/models shape."""
    data = [
        {
            "id": entry["name"].split("/")[-1],  # last path segment is the model id
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),  # Gemini has no creation time; use "now"
            "owned_by": "google",  # all Gemini models are assumed Google-owned
            "permission": [],  # no Gemini equivalent for permission info
            "root": entry["name"],
            "parent": None,  # no parent-model concept in the Gemini API
        }
        for entry in gemini_models.get("models", [])
    ]
    return {"object": "list", "data": data}
|
||||
|
||||
|
||||
def convert_messages_to_gemini_format(messages):
    """Convert OpenAI-style chat messages to Gemini "contents" entries.

    Any role other than "user" (system, assistant, ...) maps to Gemini's
    "model" role; each message's content becomes a single text part.
    """
    return [
        {
            "role": "user" if msg["role"] == "user" else "model",
            "parts": [{"text": msg["content"]}],
        }
        for msg in messages
    ]
|
||||
|
||||
|
||||
def convert_gemini_response_to_openai(response, model, stream=False):
    """Map a Gemini generateContent response onto the OpenAI schema.

    With stream=True, returns a chat.completion.chunk dict (or None when
    the chunk carries no candidates); otherwise returns a complete
    chat.completion object with zeroed usage counters.
    """
    if stream:
        # Streaming chunk path.
        if not response["candidates"]:
            return None  # empty keep-alive chunk — nothing to forward
        text = response["candidates"][0]["content"]["parts"][0]["text"]
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {"index": 0, "delta": {"content": text}, "finish_reason": None}
            ],
        }

    # Non-streaming path.
    text = response["candidates"][0]["content"]["parts"][0]["text"]
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        # Gemini does not report token counts here; keep the OpenAI shape.
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
|
||||
|
||||
|
||||
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
    """OpenAI-compatible model listing, backed by the Gemini models endpoint."""
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
        logger.info(f"Using API key: {api_key}")
    try:
        models = get_gemini_models(api_key)
        logger.info("Successfully retrieved models list")
        return models
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    """Handle an OpenAI-compatible chat completion request.

    Models listed in settings.MODEL_SEARCH are routed directly to the
    Gemini API with Google Search tooling enabled; every other model is
    forwarded to the configured OpenAI-compatible upstream. Both paths
    support streaming (SSE) and non-streaming responses. On any error the
    key's failure counter is bumped and a 500 is returned.
    """
    await verify_authorization(authorization)
    api_key = await get_next_working_key()
    logger.info(f"Using API key: {api_key}")

    try:
        logger.info(f"Chat completion request - Model: {request.model}")
        if request.model in config.settings.MODEL_SEARCH:
            # Convert the message format to Gemini "contents".
            gemini_messages = convert_messages_to_gemini_format(request.messages)

            # Call the Gemini API directly.
            non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
            stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
            payload = {
                "contents": gemini_messages,
                "generationConfig": {
                    "temperature": request.temperature,
                },
                "tools": [{"googleSearch": {}}],
            }

            if request.stream:
                logger.info("Streaming response enabled")

                async def generate():
                    # Re-emit Gemini SSE lines as OpenAI-format SSE chunks.
                    async with httpx.AsyncClient() as client:
                        async with client.stream(
                            "POST", stream_url, json=payload
                        ) as response:
                            async for line in response.aiter_lines():
                                if line.startswith("data: "):
                                    try:
                                        chunk = json.loads(line[6:])
                                        openai_chunk = (
                                            convert_gemini_response_to_openai(
                                                chunk, request.model, stream=True
                                            )
                                        )
                                        if openai_chunk:
                                            yield f"data: {json.dumps(openai_chunk)}\n\n"
                                    except json.JSONDecodeError:
                                        # Skip malformed/partial SSE payloads.
                                        continue
                            yield "data: [DONE]\n\n"

                return StreamingResponse(
                    content=generate(), media_type="text/event-stream"
                )
            else:
                # Non-streaming response
                async with httpx.AsyncClient() as client:
                    response = await client.post(non_stream_url, json=payload)
                    gemini_response = response.json()
                    openai_response = convert_gemini_response_to_openai(
                        gemini_response, request.model
                    )

                logger.info("Chat completion successful")
                return openai_response
        # Fallback: forward to the OpenAI-compatible upstream (BASE_URL).
        client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
        response = client.chat.completions.create(
            model=request.model,
            messages=request.messages,
            temperature=request.temperature,
            stream=request.stream if hasattr(request, "stream") else False,
        )

        if hasattr(request, "stream") and request.stream:
            logger.info("Streaming response enabled")

            async def generate():
                for chunk in response:
                    yield f"data: {chunk.model_dump_json()}\n\n"

            return StreamingResponse(content=generate(), media_type="text/event-stream")

        logger.info("Chat completion successful")
        return response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        api_key = await handle_api_failure(api_key)  # record the failure, possibly rotate key
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/v1/embeddings")
@app.post("/hf/v1/embeddings")
async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
    """Proxy an embeddings request to the upstream OpenAI-compatible API."""
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
        logger.info(f"Using API key: {api_key}")
    try:
        client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
        result = client.embeddings.create(input=request.input, model=request.model)
        logger.info("Embedding successful")
        return result
    except Exception as e:
        logger.error(f"Error in embedding: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/health")
@app.get("/")
async def health_check():
    """Liveness probe — always reports healthy."""
    logger.info("Health check endpoint called")
    status = {"status": "healthy"}
    return status
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Local development entry point; the container's CMD runs uvicorn directly instead.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
fastapi
|
||||
openai
|
||||
pydantic
|
||||
pydantic_settings
|
||||
uvicorn
|
||||
Reference in New Issue
Block a user