feat: Add Gemini API embeddings compatibility with embedContent and batchEmbedContents methods

This commit is contained in:
cxyfer
2025-07-30 02:26:12 +08:00
parent a6558b4668
commit b89d3ea144
5 changed files with 323 additions and 1 deletions

View File

@@ -80,3 +80,36 @@ class ResetSelectedKeysRequest(BaseModel):
class VerifySelectedKeysRequest(BaseModel):
keys: List[str]
class GeminiEmbedContent(BaseModel):
"""嵌入内容模型"""
parts: List[Dict[str, str]]
class GeminiEmbedRequest(BaseModel):
"""单一嵌入请求模型"""
content: GeminiEmbedContent
taskType: Optional[
Literal[
"TASK_TYPE_UNSPECIFIED",
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
"CODE_RETRIEVAL_QUERY",
]
] = None
title: Optional[str] = None
outputDimensionality: Optional[int] = None
class GeminiBatchEmbedRequest(BaseModel):
"""批量嵌入请求模型"""
requests: List[GeminiEmbedRequest]

View File

@@ -284,6 +284,10 @@ def get_vertex_express_logger():
return Logger.setup_logger("vertex_express")
def get_gemini_embedding_logger():
return Logger.setup_logger("gemini_embedding")
def setup_access_logging():
"""
Configure uvicorn access logging with API key redaction

View File

@@ -5,8 +5,9 @@ import asyncio
from app.config.config import settings
from app.log.logger import get_gemini_logger
from app.core.security import SecurityService
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest, GeminiEmbedRequest, GeminiBatchEmbedRequest
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.embedding.gemini_embedding_service import GeminiEmbeddingService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.tts.native.tts_routes import get_tts_chat_service
from app.service.model.model_service import ModelService
@@ -38,6 +39,11 @@ async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
return GeminiChatService(settings.BASE_URL, key_manager)
async def get_embedding_service(key_manager: KeyManager = Depends(get_key_manager)):
"""获取Gemini嵌入服务实例"""
return GeminiEmbeddingService(settings.BASE_URL, key_manager)
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
@@ -210,6 +216,63 @@ async def count_tokens(
api_key=api_key
)
return response
@router.post("/models/{model_name}:embedContent")
@router_v1beta.post("/models/{model_name}:embedContent")
@RetryHandler(key_arg="api_key")
async def embed_content(
model_name: str,
request: GeminiEmbedRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
):
"""处理 Gemini 单一嵌入请求"""
operation_name = "gemini_embed_content"
async with handle_route_errors(logger, operation_name, failure_message="Embedding content generation failed"):
logger.info(f"Handling Gemini embedding request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
response = await embedding_service.embed_content(
model=model_name,
request=request,
api_key=api_key
)
return response
@router.post("/models/{model_name}:batchEmbedContents")
@router_v1beta.post("/models/{model_name}:batchEmbedContents")
@RetryHandler(key_arg="api_key")
async def batch_embed_contents(
model_name: str,
request: GeminiBatchEmbedRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service)
):
"""处理 Gemini 批量嵌入请求"""
operation_name = "gemini_batch_embed_contents"
async with handle_route_errors(logger, operation_name, failure_message="Batch embedding content generation failed"):
logger.info(f"Handling Gemini batch embedding request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
response = await embedding_service.batch_embed_contents(
model=model_name,
request=request,
api_key=api_key
)
return response
@router.post("/reset-all-fail-counts")

View File

@@ -161,6 +161,80 @@ class GeminiApiClient(ApiClient):
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
async def embed_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
"""单一嵌入内容生成"""
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
else:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for embedding: {proxy_to_use}")
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:embedContent?key={api_key}"
try:
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}")
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
except httpx.TimeoutException as e:
logger.error(f"Embedding request timeout: {e}")
raise Exception(f"Request timeout: {e}")
except httpx.RequestError as e:
logger.error(f"Embedding request error: {e}")
raise Exception(f"Request error: {e}")
except Exception as e:
logger.error(f"Unexpected embedding error: {e}")
raise
async def batch_embed_contents(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
"""批量嵌入内容生成"""
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
else:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for batch embedding: {proxy_to_use}")
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:batchEmbedContents?key={api_key}"
try:
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}")
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
except httpx.TimeoutException as e:
logger.error(f"Batch embedding request timeout: {e}")
raise Exception(f"Request timeout: {e}")
except httpx.RequestError as e:
logger.error(f"Batch embedding request error: {e}")
raise Exception(f"Request error: {e}")
except Exception as e:
logger.error(f"Unexpected batch embedding error: {e}")
raise
class OpenaiApiClient(ApiClient):
"""OpenAI API客户端"""

View File

@@ -0,0 +1,148 @@
# app/service/embedding/gemini_embedding_service.py
import datetime
import re
import time
from typing import Any, Dict
from app.config.config import settings
from app.database.services import add_error_log, add_request_log
from app.domain.gemini_models import GeminiBatchEmbedRequest, GeminiEmbedRequest
from app.log.logger import get_gemini_embedding_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import KeyManager
logger = get_gemini_embedding_logger()
def _build_embed_payload(request: GeminiEmbedRequest) -> Dict[str, Any]:
"""构建嵌入请求payload"""
payload = {"content": request.content.model_dump()}
if request.taskType:
payload["taskType"] = request.taskType
if request.title:
payload["title"] = request.title
if request.outputDimensionality:
payload["outputDimensionality"] = request.outputDimensionality
return payload
def _build_batch_embed_payload(
request: GeminiBatchEmbedRequest, model: str
) -> Dict[str, Any]:
"""构建批量嵌入请求payload"""
requests = []
for embed_request in request.requests:
embed_payload = _build_embed_payload(embed_request)
embed_payload["model"] = (
f"models/{model}" # Gemini API要求每个请求包含model字段
)
requests.append(embed_payload)
return {"requests": requests}
class GeminiEmbeddingService:
"""Gemini嵌入服务"""
def __init__(self, base_url: str, key_manager: KeyManager):
self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
self.key_manager = key_manager
async def embed_content(
self, model: str, request: GeminiEmbedRequest, api_key: str
) -> Dict[str, Any]:
"""生成单一嵌入内容"""
payload = _build_embed_payload(request)
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
is_success = False
status_code = None
response = None
try:
response = await self.api_client.embed_content(payload, model, api_key)
is_success = True
status_code = 200
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
logger.error(f"Single embedding API call failed: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="gemini-embed-single",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
)
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime,
)
async def batch_embed_contents(
self, model: str, request: GeminiBatchEmbedRequest, api_key: str
) -> Dict[str, Any]:
"""生成批量嵌入内容"""
payload = _build_batch_embed_payload(request, model)
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
is_success = False
status_code = None
response = None
try:
response = await self.api_client.batch_embed_contents(
payload, model, api_key
)
is_success = True
status_code = 200
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
logger.error(f"Batch embedding API call failed: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="gemini-embed-batch",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload,
)
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime,
)