diff --git a/app/domain/gemini_models.py b/app/domain/gemini_models.py index bd96e4c..0a96c3c 100644 --- a/app/domain/gemini_models.py +++ b/app/domain/gemini_models.py @@ -80,3 +80,36 @@ class ResetSelectedKeysRequest(BaseModel): class VerifySelectedKeysRequest(BaseModel): keys: List[str] + + +class GeminiEmbedContent(BaseModel): + """嵌入内容模型""" + + parts: List[Dict[str, str]] + + +class GeminiEmbedRequest(BaseModel): + """单一嵌入请求模型""" + + content: GeminiEmbedContent + taskType: Optional[ + Literal[ + "TASK_TYPE_UNSPECIFIED", + "RETRIEVAL_QUERY", + "RETRIEVAL_DOCUMENT", + "SEMANTIC_SIMILARITY", + "CLASSIFICATION", + "CLUSTERING", + "QUESTION_ANSWERING", + "FACT_VERIFICATION", + "CODE_RETRIEVAL_QUERY", + ] + ] = None + title: Optional[str] = None + outputDimensionality: Optional[int] = None + + +class GeminiBatchEmbedRequest(BaseModel): + """批量嵌入请求模型""" + + requests: List[GeminiEmbedRequest] diff --git a/app/log/logger.py b/app/log/logger.py index c24bfcc..0b8c0d2 100644 --- a/app/log/logger.py +++ b/app/log/logger.py @@ -284,6 +284,10 @@ def get_vertex_express_logger(): return Logger.setup_logger("vertex_express") +def get_gemini_embedding_logger(): + return Logger.setup_logger("gemini_embedding") + + def setup_access_logging(): """ Configure uvicorn access logging with API key redaction diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index 8e9d6c6..33832ac 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -5,8 +5,9 @@ import asyncio from app.config.config import settings from app.log.logger import get_gemini_logger from app.core.security import SecurityService -from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest +from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest, GeminiEmbedRequest, GeminiBatchEmbedRequest from app.service.chat.gemini_chat_service import GeminiChatService +from app.service.embedding.gemini_embedding_service import GeminiEmbeddingService from app.service.key.key_manager import KeyManager, get_key_manager_instance from app.service.tts.native.tts_routes import get_tts_chat_service from app.service.model.model_service import ModelService @@ -38,6 +39,11 @@ async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)): return GeminiChatService(settings.BASE_URL, key_manager) +async def get_embedding_service(key_manager: KeyManager = Depends(get_key_manager)): + """获取Gemini嵌入服务实例""" + return GeminiEmbeddingService(settings.BASE_URL, key_manager) + + @router.get("/models") @router_v1beta.get("/models") async def list_models( @@ -210,6 +216,63 @@ async def count_tokens( api_key=api_key ) return response + +@router.post("/models/{model_name}:embedContent") +@router_v1beta.post("/models/{model_name}:embedContent") +@RetryHandler(key_arg="api_key") +async def embed_content( + model_name: str, + request: GeminiEmbedRequest, + _=Depends(security_service.verify_key_or_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager), + embedding_service: GeminiEmbeddingService = Depends(get_embedding_service) +): + """处理 Gemini 单一嵌入请求""" + operation_name = "gemini_embed_content" + async with handle_route_errors(logger, operation_name, failure_message="Embedding content generation failed"): + logger.info(f"Handling Gemini embedding request for model: {model_name}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {redact_key_for_logging(api_key)}") + + if not await model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + + response = await embedding_service.embed_content( + model=model_name, + request=request, + api_key=api_key + ) + return response + + +@router.post("/models/{model_name}:batchEmbedContents") +@router_v1beta.post("/models/{model_name}:batchEmbedContents") +@RetryHandler(key_arg="api_key") +async def batch_embed_contents( + model_name: str, + request: GeminiBatchEmbedRequest, + _=Depends(security_service.verify_key_or_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager), + embedding_service: GeminiEmbeddingService = Depends(get_embedding_service) +): + """处理 Gemini 批量嵌入请求""" + operation_name = "gemini_batch_embed_contents" + async with handle_route_errors(logger, operation_name, failure_message="Batch embedding content generation failed"): + logger.info(f"Handling Gemini batch embedding request for model: {model_name}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {redact_key_for_logging(api_key)}") + + if not await model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + + response = await embedding_service.batch_embed_contents( + model=model_name, + request=request, + api_key=api_key + ) + return response @router.post("/reset-all-fail-counts") diff --git a/app/service/client/api_client.py b/app/service/client/api_client.py index adc4865..56d2a00 100644 --- a/app/service/client/api_client.py +++ b/app/service/client/api_client.py @@ -161,6 +161,80 @@ class GeminiApiClient(ApiClient): raise Exception(f"API call failed with status code {response.status_code}, {error_content}") return response.json() + async def embed_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: + """单一嵌入内容生成""" + timeout = httpx.Timeout(self.timeout, read=self.timeout) + model = self._get_real_model(model) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for embedding: {proxy_to_use}") + + headers = self._prepare_headers() + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/models/{model}:embedContent?key={api_key}" + + try: + response = await client.post(url, json=payload, headers=headers) + + if response.status_code != 200: + error_content = response.text + logger.error(f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}") + raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + + return response.json() + + except httpx.TimeoutException as e: + logger.error(f"Embedding request timeout: {e}") + raise Exception(f"Request timeout: {e}") + except httpx.RequestError as e: + logger.error(f"Embedding request error: {e}") + raise Exception(f"Request error: {e}") + except Exception as e: + logger.error(f"Unexpected embedding error: {e}") + raise + + async def batch_embed_contents(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: + """批量嵌入内容生成""" + timeout = httpx.Timeout(self.timeout, read=self.timeout) + model = self._get_real_model(model) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for batch embedding: {proxy_to_use}") + + headers = self._prepare_headers() + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/models/{model}:batchEmbedContents?key={api_key}" + + try: + response = await client.post(url, json=payload, headers=headers) + + if response.status_code != 200: + error_content = response.text + logger.error(f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}") + raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + + return response.json() + + except httpx.TimeoutException as e: + logger.error(f"Batch embedding request timeout: {e}") + raise Exception(f"Request timeout: {e}") + except httpx.RequestError as e: + logger.error(f"Batch embedding request error: {e}") + raise Exception(f"Request error: {e}") + except Exception as e: + logger.error(f"Unexpected batch embedding error: {e}") + raise + class OpenaiApiClient(ApiClient): """OpenAI API客户端""" diff --git a/app/service/embedding/gemini_embedding_service.py b/app/service/embedding/gemini_embedding_service.py new file mode 100644 index 0000000..5628a20 --- /dev/null +++ b/app/service/embedding/gemini_embedding_service.py @@ -0,0 +1,148 @@ +# app/service/embedding/gemini_embedding_service.py + +import datetime +import re +import time +from typing import Any, Dict + +from app.config.config import settings +from app.database.services import add_error_log, add_request_log +from app.domain.gemini_models import GeminiBatchEmbedRequest, GeminiEmbedRequest +from app.log.logger import get_gemini_embedding_logger +from app.service.client.api_client import GeminiApiClient +from app.service.key.key_manager import KeyManager + +logger = get_gemini_embedding_logger() + + +def _build_embed_payload(request: GeminiEmbedRequest) -> Dict[str, Any]: + """构建嵌入请求payload""" + payload = {"content": request.content.model_dump()} + + if request.taskType: + payload["taskType"] = request.taskType + if request.title: + payload["title"] = request.title + if request.outputDimensionality: + payload["outputDimensionality"] = request.outputDimensionality + + return payload + + +def _build_batch_embed_payload( + request: GeminiBatchEmbedRequest, model: str +) -> Dict[str, Any]: + """构建批量嵌入请求payload""" + requests = [] + for embed_request in request.requests: + embed_payload = _build_embed_payload(embed_request) + embed_payload["model"] = ( + f"models/{model}" # Gemini API要求每个请求包含model字段 + ) + requests.append(embed_payload) + + return {"requests": requests} + + +class GeminiEmbeddingService: + """Gemini嵌入服务""" + + def __init__(self, base_url: str, key_manager: KeyManager): + self.api_client = GeminiApiClient(base_url, settings.TIME_OUT) + self.key_manager = key_manager + + async def embed_content( + self, model: str, request: GeminiEmbedRequest, api_key: str + ) -> Dict[str, Any]: + """生成单一嵌入内容""" + payload = _build_embed_payload(request) + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + status_code = None + response = None + + try: + response = await self.api_client.embed_content(payload, model, api_key) + is_success = True + status_code = 200 + return response + except Exception as e: + is_success = False + error_log_msg = str(e) + logger.error(f"Single embedding API call failed: {error_log_msg}") + match = re.search(r"status code (\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + + await add_error_log( + gemini_key=api_key, + model_name=model, + error_type="gemini-embed-single", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload, + ) + raise e + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime, + ) + + async def batch_embed_contents( + self, model: str, request: GeminiBatchEmbedRequest, api_key: str + ) -> Dict[str, Any]: + """生成批量嵌入内容""" + payload = _build_batch_embed_payload(request, model) + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + status_code = None + response = None + + try: + response = await self.api_client.batch_embed_contents( + payload, model, api_key + ) + is_success = True + status_code = 200 + return response + except Exception as e: + is_success = False + error_log_msg = str(e) + logger.error(f"Batch embedding API call failed: {error_log_msg}") + match = re.search(r"status code (\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + + await add_error_log( + gemini_key=api_key, + model_name=model, + error_type="gemini-embed-batch", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload, + ) + raise e + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime, + )