From 544f79e7901704a8a6550d921a67eac90c8b0c05 Mon Sep 17 00:00:00 2001 From: yinpeng <2291314224@qq.com> Date: Thu, 12 Dec 2024 13:31:55 +0800 Subject: [PATCH] first commit --- .gitignore | 259 ++++++++++++++++++++++++++++++++++ .vscode/launch.json | 19 +++ Dockerfile | 19 +++ README.md | 108 ++++++++++++++ app/config.py | 20 +++ main.py | 334 ++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 5 + 7 files changed, 764 insertions(+) create mode 100644 .gitignore create mode 100644 .vscode/launch.json create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 app/config.py create mode 100644 main.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..89b4e48 --- /dev/null +++ b/.gitignore @@ -0,0 +1,259 @@ +# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig +# Created by https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,circuitpython,python,pythonvanilla +# Edit at https://www.toptal.com/developers/gitignore?templates=windows,visualstudiocode,circuitpython,python,pythonvanilla + +### CircuitPython ### +.Trashes +.metadata_never_index +.fseventsd/ +boot_out.txt + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### PythonVanilla ### +# Byte-compiled / optimized / DLL files + +# C extensions + +# Distribution / packaging + +# Installer logs + +# Unit test / coverage reports + +# Translations + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
+
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+### Windows ###
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
+
+# End of https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,circuitpython,python,pythonvanilla
+
+# Custom rules (everything added below won't be overridden by 'Generate .gitignore File' if you use 'Update' option)
+
+tests/
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..0eca816
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,19 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: FastAPI",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "uvicorn",
+            "args": [
+                "main:app",
+                "--reload"
+            ],
+            "jinja": true
+        }
+    ]
+}
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e96b6e4
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,19 @@
+FROM python:3.9-slim
+
+WORKDIR /app
+
+# Copy the required files into the container
+COPY ./app /app/app
+COPY ./main.py /app
+COPY ./requirements.txt /app
+
+RUN pip install --no-cache-dir -r requirements.txt
+ENV API_KEYS=["your_api_key_1"]
+ENV ALLOWED_TOKENS=["your_token_1"]
+ENV BASE_URL=https://api.groq.com/openai/v1
+
+# Expose port
+EXPOSE 8000
+
+# Run the application
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7d28720
--- /dev/null
+++ b/README.md
@@ -0,0 +1,108 @@
+# 🚀 FastAPI OpenAI Proxy Service
+
+## 📝 Overview
+
+An OpenAI-compatible API proxy service built on FastAPI, with rotation across multiple API keys and support for streaming responses.
+
+## ✨ Features
+
+- 🔄 Rotation across multiple API keys
+- 🔐 Bearer token authentication
+- 📡 Streaming responses
+- 🌐 CORS support
+- 📊 Health-check endpoint
+
+## 🛠️ Tech Stack
+
+- FastAPI
+- OpenAI
+- Pydantic
+- Docker
+
+## 🚀 Quick Start
+
+### Requirements
+
+- Python 3.9+
+- Docker (optional)
+
+### 📦 Install dependencies
+
+```bash
+pip install -r requirements.txt
+```
+
+### ⚙️ Configuration
+
+Create a `.env` file with the following settings:
+
+```env
+API_KEYS=["your-api-key-1","your-api-key-2"]
+ALLOWED_TOKENS=["your-access-token-1","your-access-token-2"]
+BASE_URL="https://api.openai.com/v1"
+```
+
+### 🐳 Docker deployment
+
+```bash
+docker build -t openai-compatible-balance .
+docker run -p 8000:8000 -d openai-compatible-balance
+```
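+
+### 🧪 Example client call
+
+A minimal sketch of calling the proxy with the official OpenAI Python SDK; it assumes the service is running on `localhost:8000` and that `your-access-token-1` is one of the placeholder values configured in `ALLOWED_TOKENS`:
+
+```python
+from openai import OpenAI
+
+# The proxy authenticates with a token from ALLOWED_TOKENS, passed where the
+# SDK normally expects an API key; the SDK sends it as a Bearer token.
+client = OpenAI(
+    api_key="your-access-token-1",
+    base_url="http://localhost:8000/v1",
+)
+
+response = client.chat.completions.create(
+    model="llama-3.2-90b-text-preview",
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+print(response.choices[0].message.content)
+```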
+
+## 🔌 API Endpoints
+
+### List models
+
+```http
+GET /v1/models
+Authorization: Bearer your-token
+```
+
+### Chat completions
+
+```http
+POST /v1/chat/completions
+Authorization: Bearer your-token
+
+{
+    "messages": [...],
+    "model": "llama-3.2-90b-text-preview",
+    "temperature": 0.7,
+    "max_tokens": 1000,
+    "stream": false
+}
+```
+
+### Health check
+
+```http
+GET /health
+Authorization: Bearer your-token
+```
+
+## 📚 Code Structure
+
+- `main.py`: main application entry point
+- `app/config.py`: configuration management
+- `Dockerfile`: container configuration
+- `requirements.txt`: project dependencies
+
+## 🔒 Security Features
+
+- API key rotation
+- Bearer token authentication
+- Request logging
+
+## 📝 Notes
+
+- Keep your API keys and access tokens safe
+- In production, configure sensitive values via environment variables
+- The service listens on port 8000 by default
+
+## 🤝 Contributing
+
+Issues and pull requests are welcome!
+
+## 📄 License
+
+MIT License
diff --git a/app/config.py b/app/config.py
new file mode 100644
index 0000000..758f55b
--- /dev/null
+++ b/app/config.py
@@ -0,0 +1,20 @@
+from pydantic_settings import BaseSettings
+import os
+from typing import List
+
+class Settings(BaseSettings):
+    API_KEYS: List[str]
+    ALLOWED_TOKENS: List[str]
+    BASE_URL: str
+    MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
+
+    class Config:
+        env_file = ".env"
+        env_file_encoding = "utf-8"
+        case_sensitive = True
+        # Read settings from both environment variables and the .env file
+        env_nested_delimiter = "__"
+        extra = "ignore"
+
+# Environment variables take precedence; otherwise fall back to the .env file
+settings = Settings(_env_file=os.getenv("ENV_FILE", ".env"))
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..9197d96
--- /dev/null
+++ b/main.py
@@ -0,0 +1,334 @@
+from fastapi import FastAPI, HTTPException, Header
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+import openai
+from typing import List, Optional, Union
+import logging
+from itertools import cycle
+import asyncio
+
+import uvicorn
+
+from app import config
+import requests
+from datetime import datetime, timezone
+import json
+import httpx
+import uuid
+import time
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+# Allow cross-origin requests
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# API key configuration
+API_KEYS = config.settings.API_KEYS
+
+# Round-robin iterator over the configured keys
+key_cycle = cycle(API_KEYS)
+key_lock = asyncio.Lock()
+
+# Failure counter per key
+key_failure_counts = {key: 0 for key in API_KEYS}
+MAX_FAILURES = 1  # Maximum number of failures before a key is skipped
+
+
+async def get_next_working_key():
+    """Return the next API key that has not exceeded the failure threshold."""
+    async with key_lock:
+        current_key = next(key_cycle)
+        initial_key = current_key
+
+        while key_failure_counts[current_key] >= MAX_FAILURES:
+            current_key = next(key_cycle)
+            if current_key == initial_key:  # We have cycled through every key
+                # Reset all failure counts
+                for key in key_failure_counts:
+                    key_failure_counts[key] = 0
+                break
+
+        return current_key
+
+
+async def handle_api_failure(api_key):
+    """Record a failed API call and rotate to the next key if needed."""
+    async with key_lock:
+        key_failure_counts[api_key] += 1
+        exhausted = key_failure_counts[api_key] >= MAX_FAILURES
+    if exhausted:
+        logger.warning(
+            f"API key {api_key[:8]}... has failed {MAX_FAILURES} times, switching to next key"
+        )
+        # get_next_working_key() acquires key_lock itself; asyncio.Lock is not
+        # reentrant, so it must be awaited after the lock above is released.
+        return await get_next_working_key()
+    return api_key
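+
+
+# Worked example of the rotation above (hypothetical keys): with
+# API_KEYS = ["k1", "k2"] and MAX_FAILURES = 1, one failure on "k1" makes
+# handle_api_failure("k1") return "k2", and get_next_working_key() skips
+# "k1" from then on; once every key has reached the threshold, all
+# counters are reset and rotation starts over.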
+
+
+class ChatRequest(BaseModel):
+    messages: List[dict]
+    model: str = "gemini-1.5-flash-002"
+    temperature: Optional[float] = 0.7
+    stream: Optional[bool] = False
+    tools: Optional[List[dict]] = []
+    tool_choice: Optional[str] = "auto"
+
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str = "text-embedding-004"
+    encoding_format: Optional[str] = "float"
+
+
+async def verify_authorization(authorization: str = Header(None)):
+    if not authorization:
+        logger.error("Missing Authorization header")
+        raise HTTPException(status_code=401, detail="Missing Authorization header")
+    if not authorization.startswith("Bearer "):
+        logger.error("Invalid Authorization header format")
+        raise HTTPException(
+            status_code=401, detail="Invalid Authorization header format"
+        )
+    token = authorization.replace("Bearer ", "")
+    if token not in config.settings.ALLOWED_TOKENS:
+        logger.error("Invalid token")
+        raise HTTPException(status_code=401, detail="Invalid token")
+    return token
+
+
+def get_gemini_models(api_key):
+    base_url = "https://generativelanguage.googleapis.com/v1beta"
+    url = f"{base_url}/models?key={api_key}"
+
+    try:
+        response = requests.get(url, timeout=30)
+        if response.status_code == 200:
+            gemini_models = response.json()
+            return convert_to_openai_models_format(gemini_models)
+        else:
+            logger.error(f"Error fetching models: {response.status_code}")
+            logger.error(response.text)
+            return None
+
+    except requests.RequestException as e:
+        logger.error(f"Request failed: {e}")
+        return None
+
+
+def convert_to_openai_models_format(gemini_models):
+    openai_format = {"object": "list", "data": []}
+
+    for model in gemini_models.get("models", []):
+        openai_model = {
+            "id": model["name"].split("/")[-1],  # Use the last path segment as the ID
+            "object": "model",
+            "created": int(datetime.now(timezone.utc).timestamp()),  # Use the current timestamp
+            "owned_by": "google",  # Assume all Gemini models are owned by Google
+            "permission": [],  # The Gemini API has no direct equivalent for permissions
+            "root": model["name"],
+            "parent": None,  # The Gemini API has no direct equivalent for a parent model
+        }
+        openai_format["data"].append(openai_model)
+
+    return openai_format
+
+
+def convert_messages_to_gemini_format(messages):
+    """Convert OpenAI message format to Gemini format"""
+    gemini_messages = []
+    for message in messages:
+        gemini_message = {
+            # Gemini only distinguishes "user" and "model" roles
+            "role": "user" if message["role"] == "user" else "model",
+            "parts": [{"text": message["content"]}],
+        }
+        gemini_messages.append(gemini_message)
+    return gemini_messages
+
+
+def convert_gemini_response_to_openai(response, model, stream=False):
+    """Convert a Gemini response to OpenAI format"""
+    if stream:
+        # Handle a streaming chunk
+        chunk = response
+        if not chunk["candidates"]:
+            return None
+
+        return {
+            "id": "chatcmpl-" + str(uuid.uuid4()),
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "model": model,
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": {
+                        "content": chunk["candidates"][0]["content"]["parts"][0]["text"]
+                    },
+                    "finish_reason": None,
+                }
+            ],
+        }
+    else:
+        # Handle a non-streaming response
+        return {
+            "id": "chatcmpl-" + str(uuid.uuid4()),
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": model,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": response["candidates"][0]["content"]["parts"][0][
+                            "text"
+                        ],
+                    },
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+        }
+
+
+@app.get("/v1/models")
+@app.get("/hf/v1/models")
+async def list_models(authorization: str = Header(None)):
+    await verify_authorization(authorization)
+    async with key_lock:
+        api_key = next(key_cycle)
+    logger.info(f"Using API key: {api_key[:8]}...")  # Log only a key prefix
+    try:
+        response = get_gemini_models(api_key)
+        logger.info("Successfully retrieved models list")
+        return response
+    except Exception as e:
+        logger.error(f"Error listing models: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
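+
+
+# Illustrative sketch (comments only): for a model listed in MODEL_SEARCH,
+# the handler below rewrites an OpenAI-style request such as
+#     {"model": "gemini-2.0-flash-exp", "messages": [{"role": "user", "content": "Hi"}]}
+# into the Gemini REST payload
+#     {"contents": [{"role": "user", "parts": [{"text": "Hi"}]}],
+#      "generationConfig": {"temperature": 0.7},
+#      "tools": [{"googleSearch": {}}]}
+# and converts the Gemini response back into an OpenAI chat.completion object.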
+@app.post("/v1/chat/completions")
+@app.post("/hf/v1/chat/completions")
+async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
+    await verify_authorization(authorization)
+    api_key = await get_next_working_key()
+    logger.info(f"Using API key: {api_key[:8]}...")  # Log only a key prefix
+
+    try:
+        logger.info(f"Chat completion request - Model: {request.model}")
+        if request.model in config.settings.MODEL_SEARCH:
+            # Convert the message format
+            gemini_messages = convert_messages_to_gemini_format(request.messages)
+
+            # Call the Gemini API
+            non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
+            stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
+            payload = {
+                "contents": gemini_messages,
+                "generationConfig": {
+                    "temperature": request.temperature,
+                },
+                "tools": [{"googleSearch": {}}],
+            }
+
+            if request.stream:
+                logger.info("Streaming response enabled")
+
+                async def generate():
+                    async with httpx.AsyncClient() as client:
+                        async with client.stream(
+                            "POST", stream_url, json=payload
+                        ) as response:
+                            async for line in response.aiter_lines():
+                                if line.startswith("data: "):
+                                    try:
+                                        chunk = json.loads(line[6:])
+                                        openai_chunk = (
+                                            convert_gemini_response_to_openai(
+                                                chunk, request.model, stream=True
+                                            )
+                                        )
+                                        if openai_chunk:
+                                            yield f"data: {json.dumps(openai_chunk)}\n\n"
+                                    except json.JSONDecodeError:
+                                        continue
+                    yield "data: [DONE]\n\n"
+
+                return StreamingResponse(
+                    content=generate(), media_type="text/event-stream"
+                )
+            else:
+                # Non-streaming response
+                async with httpx.AsyncClient() as client:
+                    response = await client.post(non_stream_url, json=payload)
+                    response.raise_for_status()
+                    gemini_response = response.json()
+                openai_response = convert_gemini_response_to_openai(
+                    gemini_response, request.model
+                )
+
+                logger.info("Chat completion successful")
+                return openai_response
+
+        # All other models are proxied straight to BASE_URL
+        client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
+        response = client.chat.completions.create(
+            model=request.model,
+            messages=request.messages,
+            temperature=request.temperature,
+            stream=request.stream,
+        )
+
+        if request.stream:
+            logger.info("Streaming response enabled")
+
+            async def generate():
+                for chunk in response:
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+
+            return StreamingResponse(content=generate(), media_type="text/event-stream")
+
+        logger.info("Chat completion successful")
+        return response
+
+    except Exception as e:
+        logger.error(f"Error in chat completion: {str(e)}")
+        # Record the failure so subsequent requests rotate past the bad key
+        api_key = await handle_api_failure(api_key)
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/v1/embeddings")
+@app.post("/hf/v1/embeddings")
+async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
+    await verify_authorization(authorization)
+    async with key_lock:
+        api_key = next(key_cycle)
+    logger.info(f"Using API key: {api_key[:8]}...")  # Log only a key prefix
+
+    try:
+        client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
+        response = client.embeddings.create(input=request.input, model=request.model)
+        logger.info("Embedding successful")
+        return response
+    except Exception as e:
+        logger.error(f"Error in embedding: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+@app.get("/")
+async def health_check():
+    logger.info("Health check endpoint called")
+    return {"status": "healthy"}
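+
+
+# Sketch of consuming the streaming endpoint from a client (comments only;
+# assumes the placeholder token and local port used in the README):
+#     client = OpenAI(api_key="your-access-token-1", base_url="http://localhost:8000/v1")
+#     stream = client.chat.completions.create(model=..., messages=[...], stream=True)
+#     for chunk in stream:
+#         print(chunk.choices[0].delta.content or "", end="")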
host="0.0.0.0", port=8000) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bfbd29c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +fastapi +openai +pydantic +pydantic_settings +uvicorn \ No newline at end of file