mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-19 15:19:30 +08:00
feat: Add Gemini API embeddings compatibility with embedContent and batchEmbedContents methods
This commit is contained in:
@@ -80,3 +80,36 @@ class ResetSelectedKeysRequest(BaseModel):
|
||||
|
||||
# Request body model with a single required field, ``keys``.
class VerifySelectedKeysRequest(BaseModel):
    # The key strings supplied by the caller; the list itself is mandatory.
    keys: List[str]
|
||||
|
||||
|
||||
class GeminiEmbedContent(BaseModel):
    """Embedding content model."""

    # Ordered list of parts; every part is a flat string-to-string mapping
    # (presumably shaped like {"text": "..."} -- confirm against callers).
    parts: List[Dict[str, str]]
|
||||
|
||||
|
||||
# Closed set of task-type strings accepted on an embedding request.
_EmbedTaskType = Literal[
    "TASK_TYPE_UNSPECIFIED",
    "RETRIEVAL_QUERY",
    "RETRIEVAL_DOCUMENT",
    "SEMANTIC_SIMILARITY",
    "CLASSIFICATION",
    "CLUSTERING",
    "QUESTION_ANSWERING",
    "FACT_VERIFICATION",
    "CODE_RETRIEVAL_QUERY",
]


class GeminiEmbedRequest(BaseModel):
    """Single embedding request model."""

    # The content to embed (required).
    content: GeminiEmbedContent
    # Optional task type; must be one of the _EmbedTaskType values when given.
    taskType: Optional[_EmbedTaskType] = None
    # Optional title string.
    title: Optional[str] = None
    # Optional requested output dimensionality.
    outputDimensionality: Optional[int] = None
|
||||
|
||||
|
||||
class GeminiBatchEmbedRequest(BaseModel):
    """Batch embedding request model."""

    # The individual embedding requests that make up the batch, in order.
    requests: List[GeminiEmbedRequest]
|
||||
|
||||
@@ -284,6 +284,10 @@ def get_vertex_express_logger():
|
||||
return Logger.setup_logger("vertex_express")
|
||||
|
||||
|
||||
def get_gemini_embedding_logger():
    """Return the logger that ``Logger.setup_logger`` provides for "gemini_embedding"."""
    embedding_logger = Logger.setup_logger("gemini_embedding")
    return embedding_logger
|
||||
|
||||
|
||||
def setup_access_logging():
|
||||
"""
|
||||
Configure uvicorn access logging with API key redaction
|
||||
|
||||
@@ -5,8 +5,9 @@ import asyncio
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest, GeminiEmbedRequest, GeminiBatchEmbedRequest
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.embedding.gemini_embedding_service import GeminiEmbeddingService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_routes import get_tts_chat_service
|
||||
from app.service.model.model_service import ModelService
|
||||
@@ -38,6 +39,11 @@ async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
return GeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
async def get_embedding_service(key_manager: KeyManager = Depends(get_key_manager)):
    """Dependency provider: build a Gemini embedding service instance.

    The service is constructed from ``settings.BASE_URL`` and the injected
    key manager.
    """
    service = GeminiEmbeddingService(settings.BASE_URL, key_manager)
    return service
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
@router_v1beta.get("/models")
|
||||
async def list_models(
|
||||
@@ -210,6 +216,63 @@ async def count_tokens(
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
@router.post("/models/{model_name}:embedContent")
@router_v1beta.post("/models/{model_name}:embedContent")
@RetryHandler(key_arg="api_key")
async def embed_content(
    model_name: str,
    request: GeminiEmbedRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
):
    """Handle a single Gemini embedding request.

    Logs the request, rejects models that ``model_service`` does not support
    with HTTP 400, and otherwise returns whatever the embedding service
    returns for this model, request and API key.
    """
    # NOTE(review): key_manager is not read in this body; presumably it is
    # consumed by the RetryHandler decorator -- confirm before removing.
    async with handle_route_errors(
        logger,
        "gemini_embed_content",
        failure_message="Embedding content generation failed",
    ):
        logger.info(f"Handling Gemini embedding request for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {redact_key_for_logging(api_key)}")

        model_supported = await model_service.check_model_support(model_name)
        if not model_supported:
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        return await embedding_service.embed_content(
            model=model_name,
            request=request,
            api_key=api_key,
        )
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:batchEmbedContents")
@router_v1beta.post("/models/{model_name}:batchEmbedContents")
@RetryHandler(key_arg="api_key")
async def batch_embed_contents(
    model_name: str,
    request: GeminiBatchEmbedRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
):
    """Handle a Gemini batch embedding request.

    Logs the request, rejects models that ``model_service`` does not support
    with HTTP 400, and otherwise returns whatever the embedding service
    returns for this model, batch request and API key.
    """
    # NOTE(review): key_manager is not read in this body; presumably it is
    # consumed by the RetryHandler decorator -- confirm before removing.
    async with handle_route_errors(
        logger,
        "gemini_batch_embed_contents",
        failure_message="Batch embedding content generation failed",
    ):
        logger.info(f"Handling Gemini batch embedding request for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {redact_key_for_logging(api_key)}")

        model_supported = await model_service.check_model_support(model_name)
        if not model_supported:
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        return await embedding_service.batch_embed_contents(
            model=model_name,
            request=request,
            api_key=api_key,
        )
|
||||
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
|
||||
@@ -161,6 +161,80 @@ class GeminiApiClient(ApiClient):
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
async def embed_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
    """Generate a single embedding via the upstream ``embedContent`` endpoint.

    Args:
        payload: JSON-serialisable request body, posted as-is.
        model: Model name; passed through ``self._get_real_model`` before
            being placed in the URL.
        api_key: Key sent as the ``key`` query parameter. Also used to pick
            a proxy when per-key proxy selection is enabled in settings.

    Returns:
        The decoded JSON body of the upstream response.

    Raises:
        Exception: On timeout ("Request timeout: ..."), on transport error
            ("Request error: ..."), or on a non-200 response
            ("API call failed with status code <n>, <body>").
    """
    timeout = httpx.Timeout(self.timeout, read=self.timeout)
    model = self._get_real_model(model)

    proxy_to_use = None
    if settings.PROXIES:
        if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
            # NOTE(review): built-in hash() of a str is salted per process
            # unless PYTHONHASHSEED is fixed, so this mapping is stable only
            # within one process. Left unchanged on purpose.
            proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
        else:
            proxy_to_use = random.choice(settings.PROXIES)
        logger.info(f"Using proxy for embedding: {proxy_to_use}")

    headers = self._prepare_headers()
    async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
        url = f"{self.base_url}/models/{model}:embedContent?key={api_key}"

        # Only the network call lives in the try block, so that the non-200
        # failure raised below is not caught again and logged a second time
        # as an "unexpected" error.
        try:
            response = await client.post(url, json=payload, headers=headers)
        except httpx.TimeoutException as e:
            logger.error(f"Embedding request timeout: {e}")
            raise Exception(f"Request timeout: {e}") from e
        except httpx.RequestError as e:
            logger.error(f"Embedding request error: {e}")
            raise Exception(f"Request error: {e}") from e
        except Exception as e:
            logger.error(f"Unexpected embedding error: {e}")
            raise

        if response.status_code != 200:
            error_content = response.text
            logger.error(f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}")
            # Keep this wording: it carries the "status code <n>" text.
            raise Exception(f"API call failed with status code {response.status_code}, {error_content}")

        return response.json()
|
||||
|
||||
async def batch_embed_contents(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
    """Generate embeddings in bulk via the upstream ``batchEmbedContents`` endpoint.

    Args:
        payload: JSON-serialisable request body, posted as-is.
        model: Model name; passed through ``self._get_real_model`` before
            being placed in the URL.
        api_key: Key sent as the ``key`` query parameter. Also used to pick
            a proxy when per-key proxy selection is enabled in settings.

    Returns:
        The decoded JSON body of the upstream response.

    Raises:
        Exception: On timeout ("Request timeout: ..."), on transport error
            ("Request error: ..."), or on a non-200 response
            ("API call failed with status code <n>, <body>").
    """
    timeout = httpx.Timeout(self.timeout, read=self.timeout)
    model = self._get_real_model(model)

    proxy_to_use = None
    if settings.PROXIES:
        if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
            # NOTE(review): built-in hash() of a str is salted per process
            # unless PYTHONHASHSEED is fixed, so this mapping is stable only
            # within one process. Left unchanged on purpose.
            proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
        else:
            proxy_to_use = random.choice(settings.PROXIES)
        logger.info(f"Using proxy for batch embedding: {proxy_to_use}")

    headers = self._prepare_headers()
    async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
        url = f"{self.base_url}/models/{model}:batchEmbedContents?key={api_key}"

        # Only the network call lives in the try block, so that the non-200
        # failure raised below is not caught again and logged a second time
        # as an "unexpected" error.
        try:
            response = await client.post(url, json=payload, headers=headers)
        except httpx.TimeoutException as e:
            logger.error(f"Batch embedding request timeout: {e}")
            raise Exception(f"Request timeout: {e}") from e
        except httpx.RequestError as e:
            logger.error(f"Batch embedding request error: {e}")
            raise Exception(f"Request error: {e}") from e
        except Exception as e:
            logger.error(f"Unexpected batch embedding error: {e}")
            raise

        if response.status_code != 200:
            error_content = response.text
            logger.error(f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}")
            # Keep this wording: it carries the "status code <n>" text.
            raise Exception(f"API call failed with status code {response.status_code}, {error_content}")

        return response.json()
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
"""OpenAI API客户端"""
|
||||
|
||||
148
app/service/embedding/gemini_embedding_service.py
Normal file
148
app/service/embedding/gemini_embedding_service.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# app/service/embedding/gemini_embedding_service.py
|
||||
|
||||
import datetime
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.gemini_models import GeminiBatchEmbedRequest, GeminiEmbedRequest
|
||||
from app.log.logger import get_gemini_embedding_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
|
||||
logger = get_gemini_embedding_logger()
|
||||
|
||||
|
||||
def _build_embed_payload(request: GeminiEmbedRequest) -> Dict[str, Any]:
    """Build the payload for a single embedding request.

    ``content`` is always present (as ``request.content.model_dump()``).
    ``taskType``, ``title`` and ``outputDimensionality`` are copied over, in
    that order, only when their value is truthy.
    """
    body: Dict[str, Any] = {"content": request.content.model_dump()}

    for field_name in ("taskType", "title", "outputDimensionality"):
        value = getattr(request, field_name)
        if value:
            body[field_name] = value

    return body
|
||||
|
||||
|
||||
def _build_batch_embed_payload(
    request: GeminiBatchEmbedRequest, model: str
) -> Dict[str, Any]:
    """Build the payload for a batch embedding request.

    Each entry of ``request.requests`` is converted with
    ``_build_embed_payload`` and then tagged with a ``model`` key.
    """
    # The Gemini API requires every sub-request to carry its own model field.
    qualified_model = f"models/{model}"
    return {
        "requests": [
            {**_build_embed_payload(item), "model": qualified_model}
            for item in request.requests
        ]
    }
|
||||
|
||||
|
||||
class GeminiEmbeddingService:
    """Gemini embedding service.

    Forwards embedding calls to a ``GeminiApiClient``. Every call writes a
    request log via ``add_request_log``; every failed call additionally
    writes an error log via ``add_error_log`` before the exception is
    re-raised.
    """

    # Pulls the HTTP status out of exception text such as
    # "API call failed with status code 429, ...". Compiled once.
    _STATUS_CODE_PATTERN = re.compile(r"status code (\d+)")

    def __init__(self, base_url: str, key_manager: KeyManager):
        """Create the API client (timeout from ``settings.TIME_OUT``) and keep the key manager."""
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager

    async def embed_content(
        self, model: str, request: GeminiEmbedRequest, api_key: str
    ) -> Dict[str, Any]:
        """Generate a single embedding.

        Args:
            model: Model name, forwarded to the API client and to the logs.
            request: The embedding request to convert into a payload.
            api_key: Key used for the upstream call and recorded in the logs.

        Returns:
            The API client's response, unchanged.

        Raises:
            Exception: Whatever the API client raised, after it has been
                logged.
        """
        return await self._call_and_log(
            self.api_client.embed_content,
            _build_embed_payload(request),
            model,
            api_key,
            error_type="gemini-embed-single",
            failure_prefix="Single embedding API call failed",
        )

    async def batch_embed_contents(
        self, model: str, request: GeminiBatchEmbedRequest, api_key: str
    ) -> Dict[str, Any]:
        """Generate embeddings for a batch of requests.

        Args:
            model: Model name, forwarded to the API client and to the logs;
                also written into every sub-request of the payload.
            request: The batch request to convert into a payload.
            api_key: Key used for the upstream call and recorded in the logs.

        Returns:
            The API client's response, unchanged.

        Raises:
            Exception: Whatever the API client raised, after it has been
                logged.
        """
        return await self._call_and_log(
            self.api_client.batch_embed_contents,
            _build_batch_embed_payload(request, model),
            model,
            api_key,
            error_type="gemini-embed-batch",
            failure_prefix="Batch embedding API call failed",
        )

    async def _call_and_log(
        self,
        api_call,
        payload: Dict[str, Any],
        model: str,
        api_key: str,
        *,
        error_type: str,
        failure_prefix: str,
    ) -> Dict[str, Any]:
        """Run one API client call with timing, error logging and request logging.

        Args:
            api_call: Awaitable callable invoked as
                ``api_call(payload, model, api_key)``.
            payload: Request body; also stored in the error log on failure.
            model: Model name recorded in both logs.
            api_key: Key recorded in both logs.
            error_type: Value written to the error log's ``error_type``.
            failure_prefix: Text placed before the error message in the
                service log line.

        Returns:
            The value returned by ``api_call``.

        Raises:
            Exception: The original exception from ``api_call``, re-raised
                after the error log has been written.
        """
        started = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            response = await api_call(payload, model, api_key)
            is_success = True
            status_code = 200
            return response
        except Exception as e:
            error_log_msg = str(e)
            logger.error(f"{failure_prefix}: {error_log_msg}")
            # Recover the upstream status from the message text; fall back
            # to 500 when the message carries none (timeouts, transport
            # errors, and so on).
            match = self._STATUS_CODE_PATTERN.search(error_log_msg)
            status_code = int(match.group(1)) if match else 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type=error_type,
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload,
            )
            # Bare raise re-raises the exception currently being handled.
            raise
        finally:
            # Runs on both success and failure, so every call is recorded.
            latency_ms = int((time.perf_counter() - started) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )
|
||||
Reference in New Issue
Block a user