mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 06:11:32 +08:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd1fa35c73 | ||
|
|
fb572fa849 | ||
|
|
c0a473ed19 | ||
|
|
030641adc6 | ||
|
|
445ef49dc8 | ||
|
|
32d4c60541 | ||
|
|
23f865be07 | ||
|
|
5d55325c12 | ||
|
|
900330509a | ||
|
|
cfb682ae3c | ||
|
|
abae90b16d | ||
|
|
470fc37f26 | ||
|
|
7a7caef1a6 | ||
|
|
a6aecb5d89 | ||
|
|
4a004f9aa1 | ||
|
|
1a6feae23b | ||
|
|
af5b2fa2c9 | ||
|
|
eeec45274b | ||
|
|
2b48c853fe | ||
|
|
c47f696691 | ||
|
|
9a8e4c8e15 | ||
|
|
24aab9a658 | ||
|
|
afdaaffac5 | ||
|
|
fe721116e2 | ||
|
|
8e0a834daa | ||
|
|
c9fca1561c | ||
|
|
5eb2dfd822 |
10
.env.example
10
.env.example
@@ -79,4 +79,12 @@ URL_NORMALIZATION_ENABLED=false
|
||||
# tts配置
|
||||
TTS_MODEL=gemini-2.5-flash-preview-tts
|
||||
TTS_VOICE_NAME=Zephyr
|
||||
TTS_SPEED=normal
|
||||
TTS_SPEED=normal
|
||||
#########################Files API 相关配置########################
|
||||
# 是否启用文件过期自动清理
|
||||
FILES_CLEANUP_ENABLED=true
|
||||
# 文件过期清理间隔(小时)
|
||||
FILES_CLEANUP_INTERVAL_HOURS=1
|
||||
# 是否启用用户文件隔离(每个用户只能看到自己上传的文件)
|
||||
FILES_USER_ISOLATION_ENABLED=true
|
||||
##########################################################################
|
||||
@@ -8,14 +8,6 @@ COPY ./VERSION /app
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY ./app /app/app
|
||||
ENV API_KEYS='["your_api_key_1"]'
|
||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
|
||||
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
|
||||
ENV URL_NORMALIZATION_ENABLED=false
|
||||
ENV CLOUDFLARE_IMGBED_UPLOAD_FOLDER=""
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
@@ -123,6 +123,10 @@ class Settings(BaseSettings):
|
||||
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
|
||||
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
|
||||
|
||||
# Files API
|
||||
FILES_CLEANUP_ENABLED: bool = True
|
||||
FILES_CLEANUP_INTERVAL_HOURS: int = 1
|
||||
FILES_USER_ISOLATION_ENABLED: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -15,12 +15,12 @@ DEFAULT_MAX_TOKENS = 8192
|
||||
DEFAULT_TOP_P = 0.9
|
||||
DEFAULT_TOP_K = 40
|
||||
DEFAULT_FILTER_MODELS = [
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001"
|
||||
]
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001",
|
||||
]
|
||||
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||
|
||||
# 图像生成相关常量
|
||||
@@ -38,14 +38,14 @@ DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
||||
DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||
|
||||
# 正则表达式模式
|
||||
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||
IMAGE_URL_PATTERN = r"!\[(.*?)\]\((.*?)\)"
|
||||
DATA_URL_PATTERN = r"data:([^;]+);base64,(.+)"
|
||||
|
||||
# Audio/Video Settings
|
||||
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
||||
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
||||
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
|
||||
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
||||
AUDIO_FORMAT_TO_MIMETYPE = {
|
||||
@@ -63,28 +63,50 @@ VIDEO_FORMAT_TO_MIMETYPE = {
|
||||
}
|
||||
|
||||
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
|
||||
DEFAULT_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
TTS_VOICE_NAMES = [
|
||||
"Zephyr", "Puck", "Charon", "Kore",
|
||||
"Fenrir", "Leda", "Orus", "Aoede",
|
||||
"Callirhoe", "Autonoe", "Enceladus", "Iapetus",
|
||||
"Umbriel", "Algieba", "Despina", "Erinome",
|
||||
"Algenib", "Rasalgethi", "Laomedeia", "Achernar",
|
||||
"Alnilam", "Schedar", "Gacrux", "Pulcherrima",
|
||||
"Achird", "Zubenelgenubi", "Vindemiatrix", "Sadachbia",
|
||||
"Sadaltager", "Sulafat"
|
||||
]
|
||||
"Zephyr",
|
||||
"Puck",
|
||||
"Charon",
|
||||
"Kore",
|
||||
"Fenrir",
|
||||
"Leda",
|
||||
"Orus",
|
||||
"Aoede",
|
||||
"Callirrhoe",
|
||||
"Autonoe",
|
||||
"Enceladus",
|
||||
"Iapetus",
|
||||
"Umbriel",
|
||||
"Algieba",
|
||||
"Despina",
|
||||
"Erinome",
|
||||
"Algenib",
|
||||
"Rasalgethi",
|
||||
"Laomedeia",
|
||||
"Achernar",
|
||||
"Alnilam",
|
||||
"Schedar",
|
||||
"Gacrux",
|
||||
"Pulcherrima",
|
||||
"Achird",
|
||||
"Zubenelgenubi",
|
||||
"Vindemiatrix",
|
||||
"Sadachbia",
|
||||
"Sadaltager",
|
||||
"Sulafat",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
数据库模型模块
|
||||
"""
|
||||
import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum
|
||||
import enum
|
||||
|
||||
from app.database.connection import Base
|
||||
|
||||
@@ -60,3 +61,69 @@ class RequestLog(Base):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
|
||||
|
||||
|
||||
class FileState(enum.Enum):
|
||||
"""文件状态枚举"""
|
||||
PROCESSING = "PROCESSING"
|
||||
ACTIVE = "ACTIVE"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
class FileRecord(Base):
|
||||
"""
|
||||
文件记录表,用于存储上传到 Gemini 的文件信息
|
||||
"""
|
||||
__tablename__ = "t_file_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
# 文件基本信息
|
||||
name = Column(String(255), unique=True, nullable=False, comment="文件名称,格式: files/{file_id}")
|
||||
display_name = Column(String(255), nullable=True, comment="用户上传时的原始文件名")
|
||||
mime_type = Column(String(100), nullable=False, comment="MIME 类型")
|
||||
size_bytes = Column(BigInteger, nullable=False, comment="文件大小(字节)")
|
||||
sha256_hash = Column(String(255), nullable=True, comment="文件的 SHA256 哈希值")
|
||||
|
||||
# 状态信息
|
||||
state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="文件状态")
|
||||
|
||||
# 时间戳
|
||||
create_time = Column(DateTime, nullable=False, comment="创建时间")
|
||||
update_time = Column(DateTime, nullable=False, comment="更新时间")
|
||||
expiration_time = Column(DateTime, nullable=False, comment="过期时间")
|
||||
|
||||
# API 相关
|
||||
uri = Column(String(500), nullable=False, comment="文件访问 URI")
|
||||
api_key = Column(String(100), nullable=False, comment="上传时使用的 API Key")
|
||||
upload_url = Column(Text, nullable=True, comment="临时上传 URL(用于分块上传)")
|
||||
|
||||
# 额外信息
|
||||
user_token = Column(String(100), nullable=True, comment="上传用户的 token")
|
||||
upload_completed = Column(DateTime, nullable=True, comment="上传完成时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<FileRecord(name='{self.name}', state='{self.state.value if self.state else 'None'}', api_key='{self.api_key[:8]}...')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典格式,用于 API 响应"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"displayName": self.display_name,
|
||||
"mimeType": self.mime_type,
|
||||
"sizeBytes": str(self.size_bytes),
|
||||
"createTime": self.create_time.isoformat() + "Z",
|
||||
"updateTime": self.update_time.isoformat() + "Z",
|
||||
"expirationTime": self.expiration_time.isoformat() + "Z",
|
||||
"sha256Hash": self.sha256_hash,
|
||||
"uri": self.uri,
|
||||
"state": self.state.value if self.state else "PROCESSING"
|
||||
}
|
||||
|
||||
def is_expired(self):
|
||||
"""检查文件是否已过期"""
|
||||
# 确保比较时都是 timezone-aware
|
||||
expiration_time = self.expiration_time
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
|
||||
return datetime.datetime.now(datetime.timezone.utc) > expiration_time
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
数据库服务模块
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import func, desc, asc, select, insert, update, delete
|
||||
import json
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings, ErrorLog, RequestLog
|
||||
from app.database.models import Settings, ErrorLog, RequestLog, FileRecord, FileState
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
@@ -427,3 +427,264 @@ async def add_request_log(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add request log: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 文件记录相关函数 ====================
|
||||
|
||||
async def create_file_record(
|
||||
name: str,
|
||||
mime_type: str,
|
||||
size_bytes: int,
|
||||
api_key: str,
|
||||
uri: str,
|
||||
create_time: datetime,
|
||||
update_time: datetime,
|
||||
expiration_time: datetime,
|
||||
state: FileState = FileState.PROCESSING,
|
||||
display_name: Optional[str] = None,
|
||||
sha256_hash: Optional[str] = None,
|
||||
upload_url: Optional[str] = None,
|
||||
user_token: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
mime_type: MIME 类型
|
||||
size_bytes: 文件大小(字节)
|
||||
api_key: 上传时使用的 API Key
|
||||
uri: 文件访问 URI
|
||||
create_time: 创建时间
|
||||
update_time: 更新时间
|
||||
expiration_time: 过期时间
|
||||
display_name: 显示名称
|
||||
sha256_hash: SHA256 哈希值
|
||||
upload_url: 临时上传 URL
|
||||
user_token: 上传用户的 token
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 创建的文件记录
|
||||
"""
|
||||
try:
|
||||
query = insert(FileRecord).values(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
mime_type=mime_type,
|
||||
size_bytes=size_bytes,
|
||||
sha256_hash=sha256_hash,
|
||||
state=state,
|
||||
create_time=create_time,
|
||||
update_time=update_time,
|
||||
expiration_time=expiration_time,
|
||||
uri=uri,
|
||||
api_key=api_key,
|
||||
upload_url=upload_url,
|
||||
user_token=user_token
|
||||
)
|
||||
await database.execute(query)
|
||||
|
||||
# 返回创建的记录
|
||||
return await get_file_record_by_name(name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file record: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据文件名获取文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 文件记录,如果不存在则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord).where(FileRecord.name == name)
|
||||
result = await database.fetch_one(query)
|
||||
return dict(result) if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file record by name {name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def update_file_record_state(
|
||||
file_name: str,
|
||||
state: FileState,
|
||||
update_time: Optional[datetime] = None,
|
||||
upload_completed: Optional[datetime] = None,
|
||||
sha256_hash: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新文件记录状态
|
||||
|
||||
Args:
|
||||
file_name: 文件名
|
||||
state: 新状态
|
||||
update_time: 更新时间
|
||||
upload_completed: 上传完成时间
|
||||
sha256_hash: SHA256 哈希值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
try:
|
||||
values = {"state": state}
|
||||
if update_time:
|
||||
values["update_time"] = update_time
|
||||
if upload_completed:
|
||||
values["upload_completed"] = upload_completed
|
||||
if sha256_hash:
|
||||
values["sha256_hash"] = sha256_hash
|
||||
|
||||
query = update(FileRecord).where(FileRecord.name == file_name).values(**values)
|
||||
result = await database.execute(query)
|
||||
|
||||
if result:
|
||||
logger.info(f"Updated file record state for {file_name} to {state}")
|
||||
return True
|
||||
|
||||
logger.warning(f"File record not found for update: {file_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file record state: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def list_file_records(
|
||||
user_token: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
page_size: int = 10,
|
||||
page_token: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
列出文件记录
|
||||
|
||||
Args:
|
||||
user_token: 用户 token(如果提供,只返回该用户的文件)
|
||||
api_key: API Key(如果提供,只返回使用该 key 的文件)
|
||||
page_size: 每页大小
|
||||
page_token: 分页标记(偏移量)
|
||||
|
||||
Returns:
|
||||
tuple[List[Dict[str, Any]], Optional[str]]: (文件列表, 下一页标记)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"list_file_records called with page_size={page_size}, page_token={page_token}")
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time > datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
if user_token:
|
||||
query = query.where(FileRecord.user_token == user_token)
|
||||
if api_key:
|
||||
query = query.where(FileRecord.api_key == api_key)
|
||||
|
||||
# 使用偏移量进行分页
|
||||
offset = 0
|
||||
if page_token:
|
||||
try:
|
||||
offset = int(page_token)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid page token: {page_token}")
|
||||
offset = 0
|
||||
|
||||
# 按ID升序排列,使用 OFFSET 和 LIMIT
|
||||
query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
logger.debug(f"Query returned {len(results)} records")
|
||||
if results:
|
||||
logger.debug(f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}")
|
||||
|
||||
# 处理分页
|
||||
has_next = len(results) > page_size
|
||||
if has_next:
|
||||
results = results[:page_size]
|
||||
# 下一页的偏移量是当前偏移量加上本页返回的记录数
|
||||
next_offset = offset + page_size
|
||||
next_page_token = str(next_offset)
|
||||
logger.debug(f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}")
|
||||
else:
|
||||
next_page_token = None
|
||||
logger.debug(f"No next page, returning {len(results)} results")
|
||||
|
||||
return [dict(row) for row in results], next_page_token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def delete_file_record(name: str) -> bool:
|
||||
"""
|
||||
删除文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
query = delete(FileRecord).where(FileRecord.name == name)
|
||||
await database.execute(query)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file record: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_expired_file_records() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
删除已过期的文件记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 删除的记录列表
|
||||
"""
|
||||
try:
|
||||
# 先获取要删除的记录
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
||||
)
|
||||
expired_records = await database.fetch_all(query)
|
||||
|
||||
if not expired_records:
|
||||
return []
|
||||
|
||||
# 执行删除
|
||||
delete_query = delete(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
||||
)
|
||||
await database.execute(delete_query)
|
||||
|
||||
logger.info(f"Deleted {len(expired_records)} expired file records")
|
||||
return [dict(record) for record in expired_records]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete expired file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_api_key(name: str) -> Optional[str]:
|
||||
"""
|
||||
获取文件对应的 API Key
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
Optional[str]: API Key,如果文件不存在或已过期则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord.api_key).where(
|
||||
(FileRecord.name == name) &
|
||||
(FileRecord.expiration_time > datetime.now(timezone.utc))
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
return result["api_key"] if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file API key: {str(e)}")
|
||||
raise
|
||||
|
||||
69
app/domain/file_models.py
Normal file
69
app/domain/file_models.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Files API 相关的领域模型
|
||||
"""
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
|
||||
"""文件上传配置"""
|
||||
mime_type: Optional[str] = Field(None, description="MIME 类型")
|
||||
display_name: Optional[str] = Field(None, description="显示名称,最多40个字符")
|
||||
|
||||
|
||||
class CreateFileRequest(BaseModel):
|
||||
"""创建文件请求(用于初始化上传)"""
|
||||
file: Optional[Dict[str, Any]] = Field(None, description="文件元数据")
|
||||
|
||||
|
||||
class FileMetadata(BaseModel):
|
||||
"""文件元数据响应"""
|
||||
name: str = Field(..., description="文件名称,格式: files/{file_id}")
|
||||
displayName: Optional[str] = Field(None, description="显示名称")
|
||||
mimeType: str = Field(..., description="MIME 类型")
|
||||
sizeBytes: str = Field(..., description="文件大小(字节)")
|
||||
createTime: str = Field(..., description="创建时间 (RFC3339)")
|
||||
updateTime: str = Field(..., description="更新时间 (RFC3339)")
|
||||
expirationTime: str = Field(..., description="过期时间 (RFC3339)")
|
||||
sha256Hash: Optional[str] = Field(None, description="SHA256 哈希值")
|
||||
uri: str = Field(..., description="文件访问 URI")
|
||||
state: str = Field(..., description="文件状态")
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() + "Z"
|
||||
}
|
||||
|
||||
|
||||
class ListFilesRequest(BaseModel):
|
||||
"""列出文件请求参数"""
|
||||
pageSize: Optional[int] = Field(10, ge=1, le=100, description="每页大小")
|
||||
pageToken: Optional[str] = Field(None, description="分页标记")
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
"""列出文件响应"""
|
||||
files: List[FileMetadata] = Field(default_factory=list, description="文件列表")
|
||||
nextPageToken: Optional[str] = Field(None, description="下一页标记")
|
||||
|
||||
|
||||
class UploadInitResponse(BaseModel):
|
||||
"""上传初始化响应(内部使用)"""
|
||||
file_metadata: FileMetadata
|
||||
upload_url: str
|
||||
|
||||
|
||||
class FileKeyMapping(BaseModel):
|
||||
"""文件与 API Key 的映射关系(内部使用)"""
|
||||
file_name: str
|
||||
api_key: str
|
||||
user_token: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
"""删除文件响应"""
|
||||
success: bool = Field(..., description="是否删除成功")
|
||||
message: Optional[str] = Field(None, description="消息")
|
||||
@@ -41,6 +41,9 @@ class GenerationConfig(BaseModel):
|
||||
responseLogprobs: Optional[bool] = None
|
||||
logprobs: Optional[int] = None
|
||||
thinkingConfig: Optional[Dict[str, Any]] = None
|
||||
# TTS相关字段
|
||||
responseModalities: Optional[List[str]] = None
|
||||
speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
|
||||
|
||||
@@ -9,6 +9,9 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
from app.log.logger import get_openai_logger
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
class ResponseHandler(ABC):
|
||||
@@ -159,13 +162,16 @@ def _extract_result(
|
||||
gemini_format: bool = False,
|
||||
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
|
||||
text, reasoning_content, tool_calls, thought = "", "", [], None
|
||||
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
if not parts:
|
||||
logger.warning("No parts found in stream response")
|
||||
return "", None, [], None
|
||||
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
if "thought" in parts[0]:
|
||||
@@ -191,24 +197,38 @@ def _extract_result(
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
text, reasoning_content = "", ""
|
||||
if "parts" in candidate["content"]:
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "text" in part:
|
||||
if "thought" in part and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content += part["text"]
|
||||
else:
|
||||
text += part["text"]
|
||||
if "thought" in part and thought is None:
|
||||
thought = part.get("thought")
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
|
||||
# 使用安全的访问方式
|
||||
content = candidate.get("content", {})
|
||||
|
||||
if content and isinstance(content, dict):
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts:
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
if "thought" in part and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content += part["text"]
|
||||
else:
|
||||
text += part["text"]
|
||||
if "thought" in part and thought is None:
|
||||
thought = part.get("thought")
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
else:
|
||||
logger.warning(f"No parts found in content for model: {model}")
|
||||
else:
|
||||
logger.error(f"Invalid content structure for model: {model}")
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(
|
||||
candidate["content"]["parts"], gemini_format
|
||||
)
|
||||
|
||||
# 安全地获取 parts 用于工具调用提取
|
||||
parts = candidate.get("content", {}).get("parts", [])
|
||||
tool_calls = _extract_tool_calls(parts, gemini_format)
|
||||
else:
|
||||
logger.warning(f"No candidates found in response for model: {model}")
|
||||
text = "暂无返回"
|
||||
|
||||
return text, reasoning_content, tool_calls, thought
|
||||
|
||||
|
||||
@@ -250,8 +270,8 @@ def _extract_tool_calls(
|
||||
return []
|
||||
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
|
||||
tool_calls = list()
|
||||
|
||||
for i in range(len(parts)):
|
||||
part = parts[i]
|
||||
if not part or not isinstance(part, dict):
|
||||
@@ -260,7 +280,7 @@ def _extract_tool_calls(
|
||||
item = part.get("functionCall", {})
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
|
||||
if gemini_format:
|
||||
tool_calls.append(part)
|
||||
else:
|
||||
|
||||
@@ -228,6 +228,10 @@ def get_request_log_logger():
|
||||
return Logger.setup_logger("request_log")
|
||||
|
||||
|
||||
def get_files_logger():
|
||||
return Logger.setup_logger("files")
|
||||
|
||||
|
||||
def get_vertex_express_logger():
|
||||
return Logger.setup_logger("vertex_express")
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
and not request.url.path.startswith("/openai")
|
||||
and not request.url.path.startswith("/api/version/check")
|
||||
and not request.url.path.startswith("/vertex-express")
|
||||
and not request.url.path.startswith("/upload")
|
||||
):
|
||||
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
|
||||
295
app/router/files_routes.py
Normal file
295
app/router/files_routes.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Files API 路由
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request, Query, Depends, Header, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.file_models import (
|
||||
FileMetadata,
|
||||
ListFilesResponse,
|
||||
DeleteFileResponse
|
||||
)
|
||||
from app.log.logger import get_files_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.service.files.files_service import get_files_service
|
||||
from app.service.files.file_upload_handler import get_upload_handler
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
router = APIRouter()
|
||||
security_service = SecurityService()
|
||||
|
||||
|
||||
@router.post("/upload/v1beta/files")
|
||||
async def upload_file_init(
|
||||
request: Request,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
||||
x_goog_upload_protocol: Optional[str] = Header(None),
|
||||
x_goog_upload_command: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
||||
):
|
||||
"""初始化文件上传"""
|
||||
logger.debug(f"Upload file request: {request.method=}, {request.url=}, {auth_token=}, {x_goog_upload_protocol=}, {x_goog_upload_command=}, {x_goog_upload_header_content_length=}, {x_goog_upload_header_content_type=}")
|
||||
|
||||
# 檢查是否是實際的上傳請求(有 upload_id)
|
||||
if request.query_params.get("upload_id") and x_goog_upload_command in ["upload", "upload, finalize"]:
|
||||
logger.debug("This is an upload request, not initialization. Redirecting to handle_upload.")
|
||||
return await handle_upload(
|
||||
upload_path="v1beta/files",
|
||||
request=request,
|
||||
key=request.query_params.get("key"),
|
||||
auth_token=auth_token
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 获取请求体
|
||||
body = await request.body()
|
||||
|
||||
# 构建请求主机 URL
|
||||
request_host = f"{request.url.scheme}://{request.url.netloc}"
|
||||
logger.info(f"Request host: {request_host}")
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
"x-goog-upload-protocol": x_goog_upload_protocol or "resumable",
|
||||
"x-goog-upload-command": x_goog_upload_command or "start",
|
||||
}
|
||||
|
||||
if x_goog_upload_header_content_length:
|
||||
headers["x-goog-upload-header-content-length"] = x_goog_upload_header_content_length
|
||||
if x_goog_upload_header_content_type:
|
||||
headers["x-goog-upload-header-content-type"] = x_goog_upload_header_content_type
|
||||
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
response_data, response_headers = await files_service.initialize_upload(
|
||||
headers=headers,
|
||||
body=body,
|
||||
user_token=user_token,
|
||||
request_host=request_host # 傳遞請求主機
|
||||
)
|
||||
|
||||
logger.info(f"Upload initialization response: {response_data}")
|
||||
logger.info(f"Upload initialization response headers: {response_headers}")
|
||||
|
||||
logger.info(f"Upload initialization response headers: {response_data}")
|
||||
# 返回响应
|
||||
return JSONResponse(
|
||||
content=response_data,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Upload initialization failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in upload initialization: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1beta/files")
|
||||
async def list_files(
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小", alias="pageSize"),
|
||||
page_token: Optional[str] = Query(None, description="分页标记", alias="pageToken"),
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> ListFilesResponse:
|
||||
"""列出文件"""
|
||||
logger.debug(f"List files: {page_size=}, {page_token=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token(如果启用用户隔离)
|
||||
user_token = auth_token if settings.FILES_USER_ISOLATION_ENABLED else None
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
return await files_service.list_files(
|
||||
page_size=page_size,
|
||||
page_token=page_token,
|
||||
user_token=user_token
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"List files failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in list files: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1beta/files/{file_id:path}")
|
||||
async def get_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> FileMetadata:
|
||||
"""获取文件信息"""
|
||||
logger.debug(f"Get file request: {file_id=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
return await files_service.get_file(f"files/{file_id}", user_token)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Get file failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in get file: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/v1beta/files/{file_id:path}")
|
||||
async def delete_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> DeleteFileResponse:
|
||||
"""删除文件"""
|
||||
logger.info(f"Delete file: {file_id=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
success = await files_service.delete_file(f"files/{file_id}", user_token)
|
||||
|
||||
return DeleteFileResponse(
|
||||
success=success,
|
||||
message="File deleted successfully" if success else "Failed to delete file"
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Delete file failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in delete file: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# 处理上传请求的通配符路由
|
||||
@router.api_route("/upload/{upload_path:path}", methods=["GET", "POST", "PUT"])
|
||||
async def handle_upload(
|
||||
upload_path: str,
|
||||
request: Request,
|
||||
key: Optional[str] = Query(None), # 從查詢參數獲取 key
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
):
|
||||
"""处理文件上传请求"""
|
||||
try:
|
||||
logger.info(f"Handling upload request: {request.method} {upload_path}, key={key}")
|
||||
|
||||
# 從查詢參數獲取 upload_id
|
||||
upload_id = request.query_params.get("upload_id")
|
||||
if not upload_id:
|
||||
raise HTTPException(status_code=400, detail="Missing upload_id")
|
||||
|
||||
# 從 session 獲取真實的 API key
|
||||
files_service = await get_files_service()
|
||||
session_info = await files_service.get_upload_session(upload_id)
|
||||
if not session_info:
|
||||
logger.error(f"No session found for upload_id: {upload_id}")
|
||||
raise HTTPException(status_code=404, detail="Upload session not found")
|
||||
|
||||
real_api_key = session_info["api_key"]
|
||||
original_upload_url = session_info["upload_url"]
|
||||
|
||||
# 使用真實的 API key 構建完整的 Google 上傳 URL
|
||||
# 保留原始 URL 的所有參數,但使用真實的 API key
|
||||
upload_url = original_upload_url
|
||||
logger.info(f"Using real API key for upload: {real_api_key[:8]}...{real_api_key[-4:]}")
|
||||
|
||||
# 代理上传请求
|
||||
upload_handler = get_upload_handler()
|
||||
return await upload_handler.proxy_upload_request(
|
||||
request=request,
|
||||
upload_url=upload_url,
|
||||
files_service=files_service
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Upload handling failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in upload handling: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# 为兼容性添加 /gemini 前缀的路由
|
||||
@router.post("/gemini/upload/v1beta/files")
|
||||
async def gemini_upload_file_init(
|
||||
request: Request,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
||||
x_goog_upload_protocol: Optional[str] = Header(None),
|
||||
x_goog_upload_command: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
||||
):
|
||||
"""初始化文件上传(Gemini 前缀)"""
|
||||
return await upload_file_init(
|
||||
request,
|
||||
auth_token,
|
||||
x_goog_upload_protocol,
|
||||
x_goog_upload_command,
|
||||
x_goog_upload_header_content_length,
|
||||
x_goog_upload_header_content_type
|
||||
)
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files")
|
||||
async def gemini_list_files(
|
||||
page_size: int = Query(10, ge=1, le=100, alias="pageSize"),
|
||||
page_token: Optional[str] = Query(None, alias="pageToken"),
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> ListFilesResponse:
|
||||
"""列出文件(Gemini 前缀)"""
|
||||
return await list_files(page_size, page_token, auth_token)
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files/{file_id:path}")
|
||||
async def gemini_get_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> FileMetadata:
|
||||
"""获取文件信息(Gemini 前缀)"""
|
||||
return await get_file(file_id, auth_token)
|
||||
|
||||
|
||||
@router.delete("/gemini/v1beta/files/{file_id:path}")
|
||||
async def gemini_delete_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> DeleteFileResponse:
|
||||
"""删除文件(Gemini 前缀)"""
|
||||
return await delete_file(file_id, auth_token)
|
||||
@@ -8,6 +8,7 @@ from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_routes import get_tts_chat_service
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
@@ -109,11 +110,41 @@ async def generate_content(
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
# 检测是否为原生Gemini TTS请求
|
||||
is_native_tts = False
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
is_native_tts = True
|
||||
logger.info("Detected native Gemini TTS request")
|
||||
logger.info(f"TTS responseModalities: {response_modalities}")
|
||||
logger.info(f"TTS speechConfig: {speech_config}")
|
||||
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
# 所有原生TTS请求都使用TTS增强服务
|
||||
if is_native_tts:
|
||||
try:
|
||||
logger.info("Using native TTS enhanced service")
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
response = await tts_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"Native TTS processing failed, falling back to standard service: {e}")
|
||||
|
||||
# 使用标准服务处理所有其他请求(非TTS)
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
|
||||
@@ -8,7 +8,7 @@ from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_routes_logger
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes, files_routes
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.stats.stats_service import StatsService
|
||||
|
||||
@@ -34,6 +34,7 @@ def setup_routers(app: FastAPI) -> None:
|
||||
app.include_router(version_routes.router)
|
||||
app.include_router(openai_compatiable_routes.router)
|
||||
app.include_router(vertex_express_routes.router)
|
||||
app.include_router(files_routes.router)
|
||||
|
||||
setup_page_routes(app)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.error_log.error_log_service import delete_old_error_logs
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
||||
from app.service.files.files_service import get_files_service
|
||||
|
||||
logger = Logger.setup_logger("scheduler")
|
||||
|
||||
@@ -96,6 +97,26 @@ async def check_failed_keys():
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_expired_files():
|
||||
"""
|
||||
定时清理过期的文件记录
|
||||
"""
|
||||
logger.info("Starting scheduled cleanup for expired files...")
|
||||
try:
|
||||
files_service = await get_files_service()
|
||||
deleted_count = await files_service.cleanup_expired_files()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Successfully cleaned up {deleted_count} expired files.")
|
||||
else:
|
||||
logger.info("No expired files to clean up.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled file cleanup: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
def setup_scheduler():
|
||||
"""设置并启动 APScheduler"""
|
||||
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
||||
@@ -134,6 +155,20 @@ def setup_scheduler():
|
||||
logger.info(
|
||||
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
|
||||
)
|
||||
|
||||
# 新增:添加文件过期清理的定时任务,每小时执行一次
|
||||
if getattr(settings, 'FILES_CLEANUP_ENABLED', True):
|
||||
cleanup_interval = getattr(settings, 'FILES_CLEANUP_INTERVAL_HOURS', 1)
|
||||
scheduler.add_job(
|
||||
cleanup_expired_files,
|
||||
"interval",
|
||||
hours=cleanup_interval,
|
||||
id="cleanup_expired_files_job",
|
||||
name="Cleanup Expired Files",
|
||||
)
|
||||
logger.info(
|
||||
f"File cleanup job scheduled to run every {cleanup_interval} hour(s)."
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduler started with all jobs.")
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.database.services import add_error_log, add_request_log, get_file_api_key
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
@@ -27,6 +27,28 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]:
|
||||
"""從內容中提取文件引用"""
|
||||
file_names = []
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
if not isinstance(part, dict) or "fileData" not in part:
|
||||
continue
|
||||
file_data = part["fileData"]
|
||||
if "fileUri" not in file_data:
|
||||
continue
|
||||
file_uri = file_data["fileUri"]
|
||||
# 從 URI 中提取文件名
|
||||
# 1. https://generativelanguage.googleapis.com/v1beta/files/{file_id}
|
||||
match = re.match(rf"{re.escape(settings.BASE_URL)}/(files/.*)", file_uri)
|
||||
if not match:
|
||||
logger.warning(f"Invalid file URI: {file_uri}")
|
||||
continue
|
||||
file_id = match.group(1)
|
||||
file_names.append(file_id)
|
||||
logger.info(f"Found file reference: {file_id}")
|
||||
return file_names
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
|
||||
"""清理JSON Schema中Gemini API不支持的字段"""
|
||||
@@ -135,19 +157,34 @@ def _filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
request_dict = request.model_dump(by_alias=True, exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
|
||||
# 检查是否为TTS模型
|
||||
is_tts_model = "tts" in model.lower()
|
||||
|
||||
if is_tts_model:
|
||||
# TTS模型使用简化的payload,不包含tools和safetySettings
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
}
|
||||
|
||||
# 只在有systemInstruction时才添加
|
||||
if request_dict.get("systemInstruction"):
|
||||
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
||||
else:
|
||||
# 非TTS模型使用完整的payload
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
|
||||
# 确保 generationConfig 不为 None
|
||||
if payload["generationConfig"] is None:
|
||||
@@ -217,6 +254,17 @@ class GeminiChatService:
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
payload = _build_payload(model, request)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
@@ -312,6 +360,17 @@ class GeminiChatService:
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
payload = _build_payload(model, request)
|
||||
|
||||
@@ -142,6 +142,23 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _validate_and_set_max_tokens(
|
||||
payload: Dict[str, Any],
|
||||
max_tokens: Optional[int],
|
||||
logger_instance
|
||||
) -> None:
|
||||
"""验证并设置 max_tokens 参数"""
|
||||
if max_tokens is None:
|
||||
return
|
||||
|
||||
# 参数验证和处理
|
||||
if max_tokens <= 0:
|
||||
logger_instance.warning(f"Invalid max_tokens value: {max_tokens}, will not set maxOutputTokens")
|
||||
# 不设置 maxOutputTokens,让 Gemini API 使用默认值
|
||||
else:
|
||||
payload["generationConfig"]["maxOutputTokens"] = max_tokens
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest,
|
||||
messages: List[Dict[str, Any]],
|
||||
@@ -159,12 +176,16 @@ def _build_payload(
|
||||
"tools": _build_tools(request, messages),
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
if request.max_tokens is not None:
|
||||
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
|
||||
|
||||
# 处理 max_tokens 参数
|
||||
_validate_and_set_max_tokens(payload, request.max_tokens, logger)
|
||||
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if request.model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
|
||||
if request.model in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
@@ -239,27 +260,53 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
usage_metadata = response.get("usageMetadata", {})
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
|
||||
# 尝试处理响应,捕获可能的响应处理异常
|
||||
try:
|
||||
result = self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
return result
|
||||
except Exception as response_error:
|
||||
logger.error(f"Response processing failed for model {model}: {str(response_error)}")
|
||||
|
||||
# 记录详细的错误信息
|
||||
if "parts" in str(response_error):
|
||||
logger.error("Response structure issue - missing or invalid parts")
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
logger.error(f"Content structure: {content}")
|
||||
|
||||
# 重新抛出异常
|
||||
raise response_error
|
||||
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
logger.error(f"API call failed for model {model}: {error_log_msg}")
|
||||
|
||||
# 特别记录 max_tokens 相关的错误
|
||||
gen_config = payload.get('generationConfig', {})
|
||||
if "maxOutputTokens" in gen_config:
|
||||
logger.error(f"Request had maxOutputTokens: {gen_config['maxOutputTokens']}")
|
||||
|
||||
# 如果是响应处理错误,记录更多信息
|
||||
if "parts" in error_log_msg:
|
||||
logger.error("This is likely a response processing error")
|
||||
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
status_code = int(match.group(1)) if match else 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -273,6 +320,8 @@ class OpenAIChatService:
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(f"Normal completion finished - Success: {is_success}, Latency: {latency_ms}ms")
|
||||
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
|
||||
@@ -115,7 +115,7 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
request_dict = request.model_dump(by_alias=True, exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
|
||||
@@ -77,7 +77,7 @@ class GeminiApiClient(ApiClient):
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -87,13 +87,35 @@ class GeminiApiClient(ApiClient):
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(f"API call failed - Status: {response.status_code}, Content: {error_content}")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# 检查响应结构的基本信息
|
||||
if not response_data.get("candidates"):
|
||||
logger.warning("No candidates found in API response")
|
||||
|
||||
return response_data
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Request timeout: {e}")
|
||||
raise Exception(f"Request timeout: {e}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
raise Exception(f"Request error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
raise
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
1
app/service/files/__init__.py
Normal file
1
app/service/files/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Intentionally empty __init__.py file
|
||||
247
app/service/files/file_upload_handler.py
Normal file
247
app/service/files/file_upload_handler.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
文件上传处理器
|
||||
处理 Google 的可恢复上传协议
|
||||
"""
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import Request, Response, HTTPException
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.log.logger import get_files_logger
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
|
||||
class FileUploadHandler:
|
||||
"""处理文件分块上传"""
|
||||
|
||||
def __init__(self):
|
||||
self.chunk_size = 8 * 1024 * 1024 # 8MB
|
||||
|
||||
async def handle_upload_chunk(
|
||||
self,
|
||||
upload_url: str,
|
||||
request: Request,
|
||||
files_service=None # 添加 files_service 參數
|
||||
) -> Response:
|
||||
"""
|
||||
处理上传分块
|
||||
|
||||
Args:
|
||||
upload_url: 上传 URL
|
||||
request: FastAPI 请求对象
|
||||
files_service: 文件服務實例
|
||||
|
||||
Returns:
|
||||
Response: 响应对象
|
||||
"""
|
||||
try:
|
||||
# 获取请求头
|
||||
headers = {}
|
||||
|
||||
# 复制必要的上传头
|
||||
upload_headers = [
|
||||
"x-goog-upload-command",
|
||||
"x-goog-upload-offset",
|
||||
"content-type",
|
||||
"content-length"
|
||||
]
|
||||
|
||||
for header in upload_headers:
|
||||
if header in request.headers:
|
||||
# 转换为正确的格式
|
||||
key = "-".join(word.capitalize() for word in header.split("-"))
|
||||
headers[key] = request.headers[header]
|
||||
|
||||
# 读取请求体
|
||||
body = await request.body()
|
||||
|
||||
# 检查是否是最后一块
|
||||
is_final = "finalize" in headers.get("X-Goog-Upload-Command", "")
|
||||
logger.debug(f"Upload command: {headers.get('X-Goog-Upload-Command', '')}, is_final: {is_final}")
|
||||
|
||||
# 转发到真实的上传 URL
|
||||
async with AsyncClient() as client:
|
||||
response = await client.post(
|
||||
upload_url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
timeout=300.0 # 5分钟超时
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201, 308]:
|
||||
logger.error(f"Upload chunk failed: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Upload failed")
|
||||
|
||||
# 如果是最后一块,更新文件状态
|
||||
if is_final and response.status_code in [200, 201]:
|
||||
logger.debug(f"Upload finalized with status {response.status_code}")
|
||||
try:
|
||||
# 解析響應獲取文件信息
|
||||
response_data = response.json()
|
||||
logger.debug(f"Upload complete response data: {response_data}")
|
||||
file_data = response_data.get("file", {})
|
||||
|
||||
# 獲取真實的文件名
|
||||
real_file_name = file_data.get("name")
|
||||
logger.debug(f"Upload response: {response_data}")
|
||||
if real_file_name and files_service:
|
||||
logger.info(f"Upload completed, file name: {real_file_name}")
|
||||
|
||||
# 從會話中獲取信息
|
||||
session_info = await files_service.get_upload_session(upload_url)
|
||||
logger.debug(f"Retrieved session info for {upload_url}: {session_info}")
|
||||
if session_info:
|
||||
# 創建文件記錄
|
||||
now = datetime.now(timezone.utc)
|
||||
expiration_time = now + timedelta(hours=48)
|
||||
|
||||
# 處理過期時間格式(Google 可能返回納秒級精度)
|
||||
expiration_time_str = file_data.get("expirationTime", expiration_time.isoformat() + "Z")
|
||||
# 處理納秒格式:2025-07-11T02:02:52.531916141Z -> 2025-07-11T02:02:52.531916Z
|
||||
if expiration_time_str.endswith("Z"):
|
||||
# 移除 Z
|
||||
expiration_time_str = expiration_time_str[:-1]
|
||||
# 如果有納秒(超過6位小數),截斷到微秒
|
||||
if "." in expiration_time_str:
|
||||
date_part, frac_part = expiration_time_str.rsplit(".", 1)
|
||||
if len(frac_part) > 6:
|
||||
frac_part = frac_part[:6]
|
||||
expiration_time_str = f"{date_part}.{frac_part}"
|
||||
# 添加時區
|
||||
expiration_time_str += "+00:00"
|
||||
|
||||
# 獲取文件狀態(Google 可能返回 PROCESSING)
|
||||
file_state = file_data.get("state", "PROCESSING")
|
||||
logger.debug(f"File state from Google: {file_state}")
|
||||
|
||||
# 將字符串狀態轉換為枚舉
|
||||
if file_state == "ACTIVE":
|
||||
state_enum = FileState.ACTIVE
|
||||
elif file_state == "PROCESSING":
|
||||
state_enum = FileState.PROCESSING
|
||||
elif file_state == "FAILED":
|
||||
state_enum = FileState.FAILED
|
||||
else:
|
||||
logger.warning(f"Unknown file state: {file_state}, defaulting to PROCESSING")
|
||||
state_enum = FileState.PROCESSING
|
||||
|
||||
await db_services.create_file_record(
|
||||
name=real_file_name,
|
||||
mime_type=file_data.get("mimeType", session_info["mime_type"]),
|
||||
size_bytes=int(file_data.get("sizeBytes", session_info["size_bytes"])),
|
||||
api_key=session_info["api_key"],
|
||||
uri=file_data.get("uri", f"{settings.BASE_URL}/{real_file_name}"),
|
||||
create_time=now,
|
||||
update_time=now,
|
||||
expiration_time=datetime.fromisoformat(expiration_time_str),
|
||||
state=state_enum,
|
||||
display_name=file_data.get("displayName", session_info.get("display_name", "")),
|
||||
sha256_hash=file_data.get("sha256Hash"),
|
||||
user_token=session_info["user_token"]
|
||||
)
|
||||
logger.info(f"Created file record: name={real_file_name}, api_key={session_info['api_key'][:8]}...{session_info['api_key'][-4:]}")
|
||||
else:
|
||||
logger.warning(f"No upload session found for URL: {upload_url}")
|
||||
else:
|
||||
logger.warning(f"Missing real_file_name or files_service: real_file_name={real_file_name}, files_service={files_service}")
|
||||
|
||||
# 返回完整的文件信息
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file record: {str(e)}", exc_info=True)
|
||||
else:
|
||||
logger.debug(f"Upload chunk processed: is_final={is_final}, status={response.status_code}")
|
||||
|
||||
# 返回响应
|
||||
response_headers = dict(response.headers)
|
||||
|
||||
# 确保包含必要的头
|
||||
if response.status_code == 308: # Resume Incomplete
|
||||
if "x-goog-upload-status" not in response_headers:
|
||||
response_headers["x-goog-upload-status"] = "active"
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle upload chunk: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def proxy_upload_request(
|
||||
self,
|
||||
request: Request,
|
||||
upload_url: str,
|
||||
files_service=None
|
||||
) -> Response:
|
||||
"""
|
||||
代理上传请求
|
||||
|
||||
Args:
|
||||
request: FastAPI 请求对象
|
||||
upload_url: 目标上传 URL
|
||||
files_service: 文件服務實例
|
||||
|
||||
Returns:
|
||||
Response: 代理响应
|
||||
"""
|
||||
logger.debug(f"Proxy upload request: {request.method}, {upload_url}")
|
||||
try:
|
||||
# 如果是 GET 请求,返回上传状态
|
||||
if request.method == "GET":
|
||||
return await self._get_upload_status(upload_url)
|
||||
|
||||
# 处理 POST/PUT 请求
|
||||
return await self.handle_upload_chunk(upload_url, request, files_service)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to proxy upload request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def _get_upload_status(self, upload_url: str) -> Response:
|
||||
"""
|
||||
获取上传状态
|
||||
|
||||
Args:
|
||||
upload_url: 上传 URL
|
||||
|
||||
Returns:
|
||||
Response: 状态响应
|
||||
"""
|
||||
try:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(upload_url)
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get upload status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
|
||||
# 单例实例
|
||||
_upload_handler_instance: Optional[FileUploadHandler] = None
|
||||
|
||||
|
||||
def get_upload_handler() -> FileUploadHandler:
|
||||
"""获取上传处理器单例实例"""
|
||||
global _upload_handler_instance
|
||||
if _upload_handler_instance is None:
|
||||
_upload_handler_instance = FileUploadHandler()
|
||||
return _upload_handler_instance
|
||||
498
app/service/files/files_service.py
Normal file
498
app/service/files/files_service.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
文件管理服务
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from httpx import AsyncClient
|
||||
import asyncio
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.domain.file_models import FileMetadata, ListFilesResponse
|
||||
from fastapi import HTTPException
|
||||
from app.log.logger import get_files_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
# 全局上傳會話存儲
|
||||
_upload_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
_upload_sessions_lock = asyncio.Lock()
|
||||
|
||||
|
||||
class FilesService:
|
||||
"""文件管理服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_client = GeminiApiClient(base_url=settings.BASE_URL)
|
||||
self.key_manager = None
|
||||
|
||||
async def _get_key_manager(self):
|
||||
"""获取 KeyManager 实例"""
|
||||
if not self.key_manager:
|
||||
self.key_manager = await get_key_manager_instance(
|
||||
settings.API_KEYS,
|
||||
settings.VERTEX_API_KEYS
|
||||
)
|
||||
return self.key_manager
|
||||
|
||||
async def initialize_upload(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
body: Optional[bytes],
|
||||
user_token: str,
|
||||
request_host: str = None # 添加請求主機參數
|
||||
) -> Tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
初始化文件上传
|
||||
|
||||
Args:
|
||||
headers: 请求头
|
||||
body: 请求体
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], Dict[str, str]]: (响应体, 响应头)
|
||||
"""
|
||||
try:
|
||||
# 获取可用的 API key
|
||||
key_manager = await self._get_key_manager()
|
||||
api_key = await key_manager.get_next_key()
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No available API keys")
|
||||
|
||||
# 转发请求到真实的 Gemini API
|
||||
async with AsyncClient() as client:
|
||||
# 准备请求头
|
||||
forward_headers = {
|
||||
"X-Goog-Upload-Protocol": headers.get("x-goog-upload-protocol", "resumable"),
|
||||
"X-Goog-Upload-Command": headers.get("x-goog-upload-command", "start"),
|
||||
"Content-Type": headers.get("content-type", "application/json"),
|
||||
}
|
||||
|
||||
# 添加其他必要的头
|
||||
if "x-goog-upload-header-content-length" in headers:
|
||||
forward_headers["X-Goog-Upload-Header-Content-Length"] = headers["x-goog-upload-header-content-length"]
|
||||
if "x-goog-upload-header-content-type" in headers:
|
||||
forward_headers["X-Goog-Upload-Header-Content-Type"] = headers["x-goog-upload-header-content-type"]
|
||||
|
||||
# 发送请求
|
||||
response = await client.post(
|
||||
"https://generativelanguage.googleapis.com/upload/v1beta/files",
|
||||
headers=forward_headers,
|
||||
content=body,
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Upload initialization failed: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Upload initialization failed")
|
||||
|
||||
# 获取上传 URL
|
||||
upload_url = response.headers.get("x-goog-upload-url")
|
||||
if not upload_url:
|
||||
raise HTTPException(status_code=500, detail="No upload URL in response")
|
||||
|
||||
logger.info(f"Original upload URL from Google: {upload_url}")
|
||||
|
||||
|
||||
# 儲存上傳資訊到 headers 中,供後續使用
|
||||
# 不在這裡創建數據庫記錄,等到上傳完成後再創建
|
||||
logger.info(f"Upload initialized with API key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
# 解析响应 - 初始化响应可能是空的
|
||||
response_data = {}
|
||||
|
||||
# 從請求體中解析文件信息(如果有)
|
||||
display_name = ""
|
||||
if body:
|
||||
try:
|
||||
request_data = json.loads(body)
|
||||
display_name = request_data.get("displayName", "")
|
||||
except Exception:
|
||||
pass
|
||||
# 從 upload URL 中提取 upload_id
|
||||
import urllib.parse
|
||||
parsed_url = urllib.parse.urlparse(upload_url)
|
||||
query_params = urllib.parse.parse_qs(parsed_url.query)
|
||||
upload_id = query_params.get('upload_id', [None])[0]
|
||||
|
||||
if upload_id:
|
||||
# 儲存上傳會話信息,使用 upload_id 作為 key
|
||||
async with _upload_sessions_lock:
|
||||
_upload_sessions[upload_id] = {
|
||||
"api_key": api_key,
|
||||
"user_token": user_token,
|
||||
"display_name": display_name,
|
||||
"mime_type": headers.get("x-goog-upload-header-content-type", "application/octet-stream"),
|
||||
"size_bytes": int(headers.get("x-goog-upload-header-content-length", "0")),
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"upload_url": upload_url
|
||||
}
|
||||
logger.info(f"Stored upload session for upload_id={upload_id}: api_key={api_key[:8]}...{api_key[-4:]}")
|
||||
logger.debug(f"Total active sessions: {len(_upload_sessions)}")
|
||||
else:
|
||||
logger.warning(f"No upload_id found in upload URL: {upload_url}")
|
||||
|
||||
# 定期清理過期的會話(超過1小時)
|
||||
asyncio.create_task(self._cleanup_expired_sessions())
|
||||
|
||||
# 替換 Google 的 URL 為我們的代理 URL
|
||||
proxy_upload_url = upload_url
|
||||
if request_host:
|
||||
# 原始: https://generativelanguage.googleapis.com/upload/v1beta/files?key=AIzaSyDc...&upload_id=xxx&upload_protocol=resumable
|
||||
# 替換為: http://request-host/upload/v1beta/files?key=sk-123456&upload_id=xxx&upload_protocol=resumable
|
||||
|
||||
# 先替換域名
|
||||
proxy_upload_url = upload_url.replace(
|
||||
"https://generativelanguage.googleapis.com",
|
||||
request_host.rstrip('/')
|
||||
)
|
||||
|
||||
# 再替換 key 參數
|
||||
import re
|
||||
# 匹配 key=xxx 參數
|
||||
key_pattern = r'(\?|&)key=([^&]+)'
|
||||
match = re.search(key_pattern, proxy_upload_url)
|
||||
if match:
|
||||
# 替換為我們的 token
|
||||
proxy_upload_url = proxy_upload_url.replace(
|
||||
f"{match.group(1)}key={match.group(2)}",
|
||||
f"{match.group(1)}key={user_token}"
|
||||
)
|
||||
|
||||
logger.info(f"Replaced upload URL: {upload_url} -> {proxy_upload_url}")
|
||||
|
||||
return response_data, {
|
||||
"X-Goog-Upload-URL": proxy_upload_url,
|
||||
"X-Goog-Upload-Status": "active"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize upload: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def _cleanup_expired_sessions(self):
|
||||
"""清理過期的上傳會話"""
|
||||
try:
|
||||
async with _upload_sessions_lock:
|
||||
now = datetime.now(timezone.utc)
|
||||
expired_keys = []
|
||||
for key, session in _upload_sessions.items():
|
||||
if now - session["created_at"] > timedelta(hours=1):
|
||||
expired_keys.append(key)
|
||||
|
||||
for key in expired_keys:
|
||||
del _upload_sessions[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"Cleaned up {len(expired_keys)} expired upload sessions")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up upload sessions: {str(e)}")
|
||||
|
||||
async def get_upload_session(self, key: str) -> Optional[Dict[str, Any]]:
|
||||
"""獲取上傳會話信息(支持 upload_id 或完整 URL)"""
|
||||
async with _upload_sessions_lock:
|
||||
# 先嘗試直接查找
|
||||
session = _upload_sessions.get(key)
|
||||
if session:
|
||||
logger.debug(f"Found session by direct key {key}")
|
||||
return session
|
||||
|
||||
# 如果是 URL,嘗試提取 upload_id
|
||||
if key.startswith("http"):
|
||||
import urllib.parse
|
||||
parsed_url = urllib.parse.urlparse(key)
|
||||
query_params = urllib.parse.parse_qs(parsed_url.query)
|
||||
upload_id = query_params.get('upload_id', [None])[0]
|
||||
if upload_id:
|
||||
session = _upload_sessions.get(upload_id)
|
||||
if session:
|
||||
logger.debug(f"Found session by upload_id {upload_id} from URL")
|
||||
return session
|
||||
|
||||
logger.debug(f"No session found for key: {key}")
|
||||
return None
|
||||
|
||||
async def get_file(self, file_name: str, user_token: str) -> FileMetadata:
|
||||
"""
|
||||
获取文件信息
|
||||
|
||||
Args:
|
||||
file_name: 文件名称 (格式: files/{file_id})
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
FileMetadata: 文件元数据
|
||||
"""
|
||||
try:
|
||||
# 查询文件记录
|
||||
file_record = await db_services.get_file_record_by_name(file_name)
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# 检查是否过期
|
||||
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
|
||||
# 如果是 naive datetime,假设为 UTC
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
|
||||
if expiration_time <= datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=404, detail="File has expired")
|
||||
|
||||
# 使用原始 API key 获取文件信息
|
||||
api_key = file_record["api_key"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to get file: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Failed to get file")
|
||||
|
||||
file_data = response.json()
|
||||
|
||||
# 檢查並更新文件狀態
|
||||
google_state = file_data.get("state", "PROCESSING")
|
||||
if google_state != file_record.get("state", "").value if file_record.get("state") else None:
|
||||
logger.info(f"File state changed from {file_record.get('state')} to {google_state}")
|
||||
# 更新數據庫中的狀態
|
||||
if google_state == "ACTIVE":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.ACTIVE,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
elif google_state == "FAILED":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.FAILED,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
return FileMetadata(
|
||||
name=file_data["name"],
|
||||
displayName=file_data.get("displayName"),
|
||||
mimeType=file_data["mimeType"],
|
||||
sizeBytes=str(file_data["sizeBytes"]),
|
||||
createTime=file_data["createTime"],
|
||||
updateTime=file_data["updateTime"],
|
||||
expirationTime=file_data["expirationTime"],
|
||||
sha256Hash=file_data.get("sha256Hash"),
|
||||
uri=file_data["uri"],
|
||||
state=google_state
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file {file_name}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
page_size: int = 10,
|
||||
page_token: Optional[str] = None,
|
||||
user_token: Optional[str] = None
|
||||
) -> ListFilesResponse:
|
||||
"""
|
||||
列出文件
|
||||
|
||||
Args:
|
||||
page_size: 每页大小
|
||||
page_token: 分页标记
|
||||
user_token: 用户令牌(可选,如果提供则只返回该用户的文件)
|
||||
|
||||
Returns:
|
||||
ListFilesResponse: 文件列表响应
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"list_files called with page_size={page_size}, page_token={page_token}")
|
||||
|
||||
# 从数据库获取文件列表
|
||||
files, next_page_token = await db_services.list_file_records(
|
||||
user_token=user_token,
|
||||
page_size=page_size,
|
||||
page_token=page_token
|
||||
)
|
||||
|
||||
logger.debug(f"Database returned {len(files)} files, next_page_token={next_page_token}")
|
||||
|
||||
# 转换为响应格式
|
||||
file_list = []
|
||||
for file_record in files:
|
||||
file_list.append(FileMetadata(
|
||||
name=file_record["name"],
|
||||
displayName=file_record.get("display_name"),
|
||||
mimeType=file_record["mime_type"],
|
||||
sizeBytes=str(file_record["size_bytes"]),
|
||||
createTime=file_record["create_time"].isoformat() + "Z",
|
||||
updateTime=file_record["update_time"].isoformat() + "Z",
|
||||
expirationTime=file_record["expiration_time"].isoformat() + "Z",
|
||||
sha256Hash=file_record.get("sha256_hash"),
|
||||
uri=file_record["uri"],
|
||||
state=file_record["state"].value if file_record.get("state") else "ACTIVE"
|
||||
))
|
||||
|
||||
response = ListFilesResponse(
|
||||
files=file_list,
|
||||
nextPageToken=next_page_token
|
||||
)
|
||||
|
||||
logger.debug(f"Returning response with {len(response.files)} files, nextPageToken={response.nextPageToken}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list files: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def delete_file(self, file_name: str, user_token: str) -> bool:
|
||||
"""
|
||||
删除文件
|
||||
|
||||
Args:
|
||||
file_name: 文件名称
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
# 查询文件记录
|
||||
file_record = await db_services.get_file_record_by_name(file_name)
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# 使用原始 API key 删除文件
|
||||
api_key = file_record["api_key"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.delete(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 204]:
|
||||
logger.error(f"Failed to delete file: {response.status_code} - {response.text}")
|
||||
# 如果 API 删除失败,但文件已过期,仍然删除数据库记录
|
||||
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
|
||||
if expiration_time <= datetime.now(timezone.utc):
|
||||
await db_services.delete_file_record(file_name)
|
||||
return True
|
||||
raise HTTPException(status_code=response.status_code, detail="Failed to delete file")
|
||||
|
||||
# 删除数据库记录
|
||||
await db_services.delete_file_record(file_name)
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {file_name}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def check_file_state(self, file_name: str, api_key: str) -> str:
|
||||
"""
|
||||
檢查並更新文件狀態
|
||||
|
||||
Args:
|
||||
file_name: 文件名稱
|
||||
api_key: API密鑰
|
||||
|
||||
Returns:
|
||||
str: 當前狀態
|
||||
"""
|
||||
try:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to check file state: {response.status_code}")
|
||||
return "UNKNOWN"
|
||||
|
||||
file_data = response.json()
|
||||
google_state = file_data.get("state", "PROCESSING")
|
||||
|
||||
# 更新數據庫狀態
|
||||
if google_state == "ACTIVE":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.ACTIVE,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
elif google_state == "FAILED":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.FAILED,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
return google_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check file state: {str(e)}")
|
||||
return "UNKNOWN"
|
||||
|
||||
async def cleanup_expired_files(self) -> int:
|
||||
"""
|
||||
清理过期文件
|
||||
|
||||
Returns:
|
||||
int: 清理的文件数量
|
||||
"""
|
||||
try:
|
||||
# 获取过期文件
|
||||
expired_files = await db_services.delete_expired_file_records()
|
||||
|
||||
if not expired_files:
|
||||
return 0
|
||||
|
||||
# 尝试从 Gemini API 删除文件
|
||||
for file_record in expired_files:
|
||||
try:
|
||||
api_key = file_record["api_key"]
|
||||
file_name = file_record["name"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
await client.delete(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
except Exception as e:
|
||||
# 记录错误但继续处理其他文件
|
||||
logger.error(f"Failed to delete file {file_record['name']} from API: {str(e)}")
|
||||
|
||||
return len(expired_files)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup expired files: {str(e)}")
|
||||
return 0
|
||||
|
||||
|
||||
# 单例实例
|
||||
_files_service_instance: Optional[FilesService] = None
|
||||
|
||||
|
||||
async def get_files_service() -> FilesService:
|
||||
"""获取文件服务单例实例"""
|
||||
global _files_service_instance
|
||||
if _files_service_instance is None:
|
||||
_files_service_instance = FilesService()
|
||||
return _files_service_instance
|
||||
363
app/service/tts/native/README.md
Normal file
363
app/service/tts/native/README.md
Normal file
@@ -0,0 +1,363 @@
|
||||
# 原生Gemini TTS功能
|
||||
|
||||
这个模块为Gemini Balance项目添加了原生Gemini TTS(Text-to-Speech)功能,支持单人和多人语音合成,采用智能检测和继承模式设计,保持与原始代码的完全兼容性。
|
||||
|
||||
## 🎯 设计原则
|
||||
|
||||
- **智能检测**:自动检测所有原生Gemini TTS格式的请求(包含responseModalities和speechConfig)
|
||||
- **继承而非修改**:所有扩展都继承自原始类,不修改源码
|
||||
- **完全兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **动态模型选择**:支持用户在请求URL中指定不同的TTS模型
|
||||
- **自动回退**:原生TTS处理失败时自动回退到标准服务
|
||||
- **完整日志记录**:包含请求日志、错误日志和性能监控
|
||||
- **易于维护**:更新原始代码时不会产生冲突
|
||||
|
||||
## 📁 文件结构
|
||||
|
||||
```
|
||||
app/service/tts/
|
||||
├── tts_service.py # 原有的OpenAI兼容TTS服务
|
||||
└── native/ # 原生Gemini TTS扩展
|
||||
├── __init__.py # 模块初始化
|
||||
├── README.md # 使用说明(本文件)
|
||||
├── tts_models.py # TTS数据模型(继承自原始模型)
|
||||
├── tts_response_handler.py # TTS响应处理器(继承自原始处理器)
|
||||
├── tts_chat_service.py # TTS聊天服务(继承自原始服务)
|
||||
└── tts_routes.py # TTS路由扩展和依赖注入
|
||||
```
|
||||
|
||||
## 🚀 原生Gemini TTS功能
|
||||
|
||||
### 智能检测机制(当前实现)
|
||||
|
||||
原生Gemini TTS功能通过智能检测自动启用,无需任何配置:
|
||||
|
||||
1. **自动启用**:
|
||||
```bash
|
||||
# 直接启动服务,原生TTS功能自动可用
|
||||
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
2. **无需配置**:
|
||||
- 不需要环境变量
|
||||
- 不需要修改配置文件
|
||||
- 完全基于请求内容智能判断
|
||||
|
||||
### 工作原理
|
||||
|
||||
系统会智能检测请求内容:
|
||||
- **原生TTS请求**:包含 `responseModalities: ["AUDIO"]` 和 `speechConfig` → 使用TTS增强服务
|
||||
- **单人TTS**:包含 `voiceConfig.prebuiltVoiceConfig`
|
||||
- **多人TTS**:包含 `multiSpeakerVoiceConfig`
|
||||
- **普通请求**:非TTS模型 → 使用原有Gemini聊天服务
|
||||
|
||||
```python
|
||||
# app/router/gemini_routes.py 中的智能检测逻辑
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
# 使用TTS增强服务
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
return await tts_service.generate_content(...)
|
||||
# 否则使用原有服务
|
||||
```
|
||||
|
||||
## 📝 使用示例
|
||||
|
||||
### 1. 原生Gemini单人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `voiceConfig.prebuiltVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Hello, this is a single speaker test."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. 原生Gemini多人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `multiSpeakerVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Alice: Hello everyone, welcome to our show today.\nBob: Hi Alice, and hello to all our listeners! Today we are talking about AI development."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "Alice",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Puck"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"speaker": "Bob",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 3. OpenAI兼容TTS请求(使用原有服务)
|
||||
|
||||
OpenAI兼容格式的TTS请求使用不同的API路径,不受本模块影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1/audio/speech" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-token" \
|
||||
-d '{
|
||||
"model": "tts-1",
|
||||
"input": "这是一个OpenAI兼容格式的TTS测试。",
|
||||
"voice": "alloy"
|
||||
}' \
|
||||
--output openai_tts_test.wav
|
||||
```
|
||||
|
||||
**注意**:OpenAI兼容TTS请求:
|
||||
- 使用路径:`/v1/audio/speech`
|
||||
- 使用Authorization头而不是x-goog-api-key
|
||||
- 返回音频文件而不是JSON响应
|
||||
- 不受本模块的TTS增强服务影响
|
||||
|
||||
### 普通文本生成(使用原有服务)
|
||||
|
||||
非TTS模型的请求会使用原有的Gemini聊天服务,完全不受影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "请简单介绍一下人工智能的发展历程。"
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"maxOutputTokens": 200,
|
||||
"temperature": 0.7
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## 🔧 技术实现
|
||||
|
||||
### 继承关系
|
||||
|
||||
```
|
||||
GeminiChatService
|
||||
↓ (继承)
|
||||
TTSGeminiChatService
|
||||
├── 重写 generate_content() 方法
|
||||
├── 添加 _handle_tts_request() 方法
|
||||
└── 集成完整的日志记录功能
|
||||
|
||||
GeminiResponseHandler
|
||||
↓ (继承)
|
||||
TTSResponseHandler
|
||||
└── 重写 handle_response() 方法
|
||||
|
||||
GenerationConfig (Pydantic模型)
|
||||
↓ (扩展)
|
||||
TTSGenerationConfig
|
||||
├── responseModalities: List[str]
|
||||
└── speechConfig: Dict[str, Any]
|
||||
```
|
||||
|
||||
### 工作流程
|
||||
|
||||
1. **请求接收**:系统接收到API请求
|
||||
2. **智能检测**:
|
||||
- 检查模型名称是否包含 "tts"
|
||||
- 如果是TTS模型,从 `request.generationConfig` 检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig`
|
||||
3. **服务选择**:
|
||||
- **原生TTS请求**:使用 `TTSGeminiChatService` 增强服务
|
||||
- **普通请求**:使用原有 `GeminiChatService`
|
||||
4. **请求处理**:
|
||||
- **原生TTS**:使用 `_handle_tts_request()` 特殊处理
|
||||
- **其他请求**:使用标准 `generate_content()` 方法
|
||||
5. **字段处理**:从 `request.generationConfig` 直接获取TTS字段(`responseModalities`, `speechConfig`)
|
||||
6. **API调用**:构建优化的payload并调用Gemini API
|
||||
7. **自动回退**:如果原生TTS处理失败,自动回退到标准服务
|
||||
8. **响应处理**:
|
||||
- **TTS响应**:检测音频数据,直接返回原始响应
|
||||
- **普通响应**:使用标准处理方法
|
||||
9. **日志记录**:记录请求时间、成功状态、错误信息到数据库
|
||||
|
||||
## 📊 功能特性
|
||||
|
||||
### ✅ 已实现功能
|
||||
|
||||
- **智能原生TTS支持**:支持单人和多人语音合成
|
||||
- **单人TTS**:支持 `voiceConfig.prebuiltVoiceConfig` 配置
|
||||
- **多人TTS**:支持 `multiSpeakerVoiceConfig` 配置
|
||||
- **智能检测机制**:自动检测所有原生Gemini TTS格式的请求
|
||||
- **动态模型选择**:支持用户在URL中指定不同TTS模型
|
||||
- **完全向后兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **自动回退机制**:原生TTS处理失败时自动使用标准服务
|
||||
- **完整日志记录**:请求日志、错误日志、性能监控
|
||||
- **API配额管理**:自动重试和密钥轮换
|
||||
- **零配置启用**:无需环境变量或配置文件修改
|
||||
- **错误处理**:完整的异常捕获和错误记录
|
||||
|
||||
### 🎵 支持的语音配置
|
||||
|
||||
#### 单人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 多人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "角色名称",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### API要求
|
||||
- 确保API密钥有TTS权限
|
||||
- TTS功能需要 `gemini-2.5-flash-preview-tts` 模型
|
||||
- 注意API配额限制(免费版每天15次)
|
||||
|
||||
### 性能考虑
|
||||
- TTS响应通常比文本响应更大(音频数据)
|
||||
- 建议监控API调用频率和成功率
|
||||
- 扩展功能不影响原始功能的性能和稳定性
|
||||
|
||||
### 部署建议
|
||||
- 生产环境建议先测试普通功能
|
||||
- 逐步启用TTS功能并监控日志
|
||||
- 定期检查API配额使用情况
|
||||
|
||||
## 📈 监控和调试
|
||||
|
||||
### 日志查看
|
||||
- **服务器日志**:查看TTS请求处理过程
|
||||
- **管理界面**:在"API 调用详情"中查看请求记录
|
||||
- **错误日志**:查看失败请求的详细信息
|
||||
|
||||
### 调试技巧
|
||||
```bash
|
||||
# 启用详细日志
|
||||
export LOG_LEVEL=DEBUG
|
||||
|
||||
# 查看实时日志
|
||||
tail -f logs/app.log
|
||||
|
||||
# 多人TTS功能无需配置,自动启用
|
||||
# 可通过请求内容智能检测
|
||||
```
|
||||
|
||||
## 🔄 TTS系统对比
|
||||
|
||||
项目中现在有三套TTS系统,各自服务不同的用途:
|
||||
|
||||
| TTS类型 | 路径 | 模型选择 | 语音配置 | 使用场景 | 我们的影响 |
|
||||
|---------|------|----------|----------|----------|------------|
|
||||
| **OpenAI兼容TTS** | `/v1/audio/speech` | 固定配置文件 | 单人语音 | OpenAI API兼容 | ✅ 无影响 |
|
||||
| **Gemini单人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 单人语音 | 原生Gemini TTS | ✅ 我们的增强 |
|
||||
| **Gemini多人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 多人语音 | 对话场景 | ✅ 我们的增强 |
|
||||
|
||||
### 智能路由机制
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[API请求] --> B{路径检查}
|
||||
B -->|/v1/audio/speech| C[OpenAI兼容TTS服务]
|
||||
B -->|/v1beta/models/{model}:generateContent| D{模型名包含'tts'?}
|
||||
D -->|否| E[标准Gemini聊天服务]
|
||||
D -->|是| F{包含responseModalities和speechConfig?}
|
||||
F -->|否| G[标准Gemini聊天服务]
|
||||
F -->|是| H[原生TTS增强服务]
|
||||
H --> I{处理成功?}
|
||||
I -->|是| J[返回原生TTS响应]
|
||||
I -->|否| K[自动回退到标准服务]
|
||||
C --> L[完成]
|
||||
E --> L
|
||||
G --> L
|
||||
J --> L
|
||||
K --> L
|
||||
```
|
||||
|
||||
## 🎉 成功案例
|
||||
|
||||
基于智能检测的原生Gemini TTS解决方案已经成功实现:
|
||||
|
||||
- ✅ **零配置启用**:无需任何环境变量或配置修改
|
||||
- ✅ **智能检测**:自动检测所有原生Gemini TTS格式的请求
|
||||
- ✅ **完全向后兼容**:所有原有TTS功能零影响
|
||||
- ✅ **动态模型选择**:支持用户指定不同TTS模型
|
||||
- ✅ **自动回退机制**:处理失败时自动使用标准服务
|
||||
- ✅ **单人和多人语音合成**:支持所有原生Gemini TTS场景
|
||||
- ✅ **完整日志记录**:可在管理界面查看所有请求
|
||||
- ✅ **错误处理完善**:API配额和重试机制
|
||||
- ✅ **易于维护**:更新原始代码无冲突
|
||||
|
||||
这个实现展示了如何在不修改原始代码的情况下,优雅地扩展复杂系统的功能,同时保持完美的向后兼容性。
|
||||
19
app/service/tts/native/__init__.py
Normal file
19
app/service/tts/native/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
原生Gemini TTS功能模块
|
||||
Native Gemini TTS functionality for both single and multi-speaker scenarios
|
||||
"""
|
||||
|
||||
from .tts_chat_service import TTSGeminiChatService
|
||||
from .tts_models import TTSGenerationConfig, MultiSpeakerVoiceConfig, SpeechConfig, TTSRequest
|
||||
from .tts_response_handler import TTSResponseHandler
|
||||
from .tts_routes import get_tts_chat_service
|
||||
|
||||
__all__ = [
|
||||
"TTSGeminiChatService",
|
||||
"TTSGenerationConfig",
|
||||
"MultiSpeakerVoiceConfig",
|
||||
"SpeechConfig",
|
||||
"TTSRequest",
|
||||
"TTSResponseHandler",
|
||||
"get_tts_chat_service"
|
||||
]
|
||||
151
app/service/tts/native/tts_chat_service.py
Normal file
151
app/service/tts/native/tts_chat_service.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
原生Gemini TTS聊天服务扩展
|
||||
继承自原始聊天服务,添加原生Gemini TTS支持(单人和多人),保持向后兼容
|
||||
"""
|
||||
|
||||
import time
|
||||
import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_response_handler import TTSResponseHandler
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.database.services import add_request_log, add_error_log
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSGeminiChatService(GeminiChatService):
|
||||
"""
|
||||
支持TTS的Gemini聊天服务
|
||||
继承自原始的GeminiChatService,添加TTS功能
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager):
|
||||
"""
|
||||
初始化TTS聊天服务
|
||||
"""
|
||||
super().__init__(base_url, key_manager)
|
||||
# 使用TTS响应处理器替换原始处理器
|
||||
self.response_handler = TTSResponseHandler()
|
||||
logger.info("TTS Gemini Chat Service initialized with multi-speaker TTS support")
|
||||
|
||||
async def generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成内容,支持TTS
|
||||
"""
|
||||
try:
|
||||
# 添加调试日志
|
||||
logger.info(f"TTS request model: {model}")
|
||||
logger.info(f"TTS request generationConfig: {request.generationConfig}")
|
||||
|
||||
# 检查是否是TTS模型,如果是,需要特殊处理
|
||||
if "tts" in model.lower():
|
||||
logger.info("Detected TTS model, applying TTS-specific processing")
|
||||
# 对于TTS模型,我们需要确保正确的字段被传递
|
||||
response = await self._handle_tts_request(model, request, api_key)
|
||||
return response
|
||||
else:
|
||||
# 对于非TTS模型,使用父类的方法
|
||||
response = await super().generate_content(model, request, api_key)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"TTS API call failed with error: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_tts_request(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
|
||||
"""
|
||||
处理TTS特定的请求,包含完整的日志记录功能
|
||||
"""
|
||||
# 记录开始时间和请求时间
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
|
||||
try:
|
||||
# 构建TTS专用的payload - 不包含tools和safetySettings
|
||||
from app.service.chat.gemini_chat_service import _filter_empty_parts
|
||||
|
||||
request_dict = request.model_dump(by_alias=True, exclude_none=False)
|
||||
|
||||
# 构建TTS专用的简化payload
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"generationConfig": request_dict.get("generationConfig", {}),
|
||||
}
|
||||
|
||||
# 只在有systemInstruction时才添加
|
||||
if request_dict.get("systemInstruction"):
|
||||
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
||||
|
||||
# 确保 generationConfig 不为 None
|
||||
if payload["generationConfig"] is None:
|
||||
payload["generationConfig"] = {}
|
||||
|
||||
# 从request.generationConfig直接获取TTS相关字段
|
||||
if request.generationConfig:
|
||||
# 添加TTS特定字段
|
||||
if request.generationConfig.responseModalities:
|
||||
payload["generationConfig"]["responseModalities"] = request.generationConfig.responseModalities
|
||||
logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}")
|
||||
|
||||
if request.generationConfig.speechConfig:
|
||||
payload["generationConfig"]["speechConfig"] = request.generationConfig.speechConfig
|
||||
logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}")
|
||||
else:
|
||||
logger.warning("No generationConfig found in request, TTS fields may be missing")
|
||||
|
||||
logger.info(f"TTS payload before API call: {payload}")
|
||||
|
||||
# 调用API
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
|
||||
# 如果到达这里,说明API调用成功
|
||||
is_success = True
|
||||
status_code = 200
|
||||
|
||||
# 使用TTS响应处理器处理响应
|
||||
return self.response_handler.handle_response(response, model, False, None)
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
is_success = False
|
||||
error_msg = str(e)
|
||||
|
||||
# 尝试从错误消息中提取状态码
|
||||
import re
|
||||
match = re.search(r"status code (\d+)", error_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# 添加错误日志
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="tts-api-error",
|
||||
error_log=error_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request.model_dump(by_alias=True, exclude_none=False)
|
||||
)
|
||||
|
||||
logger.error(f"TTS API call failed: {error_msg}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
# 记录请求日志
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
37
app/service/tts/native/tts_config.py
Normal file
37
app/service/tts/native/tts_config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
TTS扩展配置
|
||||
控制是否启用TTS功能
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
class TTSConfig:
|
||||
"""TTS配置管理"""
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled() -> bool:
|
||||
"""
|
||||
检查是否启用TTS功能
|
||||
通过环境变量 ENABLE_TTS 控制,默认为 False
|
||||
"""
|
||||
return os.getenv("ENABLE_TTS", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
@staticmethod
|
||||
def get_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
|
||||
"""
|
||||
工厂方法:根据配置返回合适的聊天服务
|
||||
"""
|
||||
if TTSConfig.is_tts_enabled():
|
||||
return TTSGeminiChatService(base_url, key_manager)
|
||||
else:
|
||||
return GeminiChatService(base_url, key_manager)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def create_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
|
||||
"""创建聊天服务实例"""
|
||||
return TTSConfig.get_chat_service(base_url, key_manager)
|
||||
36
app/service/tts/native/tts_models.py
Normal file
36
app/service/tts/native/tts_models.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
原生Gemini TTS扩展数据模型
|
||||
继承自原始模型,添加原生Gemini TTS相关字段,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.domain.gemini_models import GenerationConfig as BaseGenerationConfig
|
||||
|
||||
|
||||
class TTSGenerationConfig(BaseGenerationConfig):
|
||||
"""
|
||||
支持TTS的生成配置类
|
||||
继承自原始的GenerationConfig,添加TTS相关字段
|
||||
"""
|
||||
# TTS 相关配置
|
||||
responseModalities: Optional[List[str]] = None
|
||||
speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class MultiSpeakerVoiceConfig(BaseModel):
|
||||
"""多人语音配置"""
|
||||
speakerVoiceConfigs: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SpeechConfig(BaseModel):
|
||||
"""语音配置"""
|
||||
multiSpeakerVoiceConfig: Optional[MultiSpeakerVoiceConfig] = None
|
||||
voiceConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
"""TTS请求模型"""
|
||||
contents: List[Dict[str, Any]]
|
||||
generationConfig: TTSGenerationConfig
|
||||
53
app/service/tts/native/tts_response_handler.py
Normal file
53
app/service/tts/native/tts_response_handler.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
原生Gemini TTS响应处理器扩展
|
||||
继承自原始响应处理器,添加原生Gemini TTS支持,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.log.logger import get_gemini_logger
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSResponseHandler(GeminiResponseHandler):
|
||||
"""
|
||||
支持TTS的响应处理器
|
||||
继承自原始的GeminiResponseHandler,添加TTS响应处理
|
||||
"""
|
||||
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理响应,支持TTS音频数据
|
||||
"""
|
||||
# 检查是否是TTS响应(包含音频数据)
|
||||
if self._is_tts_response(response):
|
||||
logger.info("Detected TTS response with audio data, returning original response")
|
||||
return response
|
||||
|
||||
# 对于非TTS响应,使用父类的处理方法
|
||||
return super().handle_response(response, model, stream, usage_metadata)
|
||||
|
||||
def _is_tts_response(self, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
检查是否是TTS响应
|
||||
"""
|
||||
try:
|
||||
if (response.get("candidates") and
|
||||
len(response["candidates"]) > 0 and
|
||||
response["candidates"][0].get("content") and
|
||||
response["candidates"][0]["content"].get("parts") and
|
||||
len(response["candidates"][0]["content"]["parts"]) > 0):
|
||||
|
||||
parts = response["candidates"][0]["content"]["parts"]
|
||||
for part in parts:
|
||||
if "inlineData" in part:
|
||||
mime_type = part["inlineData"].get("mimeType", "")
|
||||
if mime_type.startswith("audio/"):
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking TTS response: {e}")
|
||||
return False
|
||||
24
app/service/tts/native/tts_routes.py
Normal file
24
app/service/tts/native/tts_routes.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
TTS路由扩展
|
||||
提供原生Gemini TTS增强服务,支持单人和多人语音
|
||||
"""
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from app.config.config import settings
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
"""获取密钥管理器实例"""
|
||||
return get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_tts_chat_service(key_manager: KeyManager = Depends(get_key_manager)) -> TTSGeminiChatService:
|
||||
"""
|
||||
获取原生Gemini TTS增强聊天服务实例,支持单人和多人语音
|
||||
"""
|
||||
return TTSGeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@@ -1618,7 +1618,7 @@ endblock %} {% block head_extra_styles %}
|
||||
<option value="Leda">Leda (年轻)</option>
|
||||
<option value="Orus">Orus (坚定)</option>
|
||||
<option value="Aoede">Aoede (轻松)</option>
|
||||
<option value="Callirhoe">Callirhoe (随和)</option>
|
||||
<option value="Callirrhoe">Callirrhoe (随和)</option>
|
||||
<option value="Autonoe">Autonoe (明亮)</option>
|
||||
<option value="Enceladus">Enceladus (呼吸感)</option>
|
||||
<option value="Iapetus">Iapetus (清晰)</option>
|
||||
|
||||
Reference in New Issue
Block a user