Files
BiliNote/backend/app/services/provider.py
techotaku39 4425239717 fix(backend): 防御 API Key 掩码污染并修复 EXE 版 .env 加载路径
- provider.py: 更新供应商时,若 api_key 包含 '*'(掩码字符),
  跳过该字段,防止前端展示用的 mask_key() 值被误写入数据库。

- ffmpeg_helper.py: load_dotenv() 默认只从 CWD 查找 .env,
  PyInstaller 打包后 CWD 为 EXE 目录,导致 _internal/.env 被忽略。
  改为遍历多个候选路径(CWD、脚本目录、项目根目录、_internal/),
  确保源码和打包两种场景都能正确加载环境变量。
2026-05-23 22:49:56 +08:00

152 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi.encoders import jsonable_encoder
from kombu import uuid
from app.db.models.providers import Provider
from app.db.provider_dao import (
insert_provider,
get_all_providers,
get_provider_by_name,
get_provider_by_id,
update_provider,
delete_provider, get_enabled_providers,
)
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
class ProviderService:
@staticmethod
def serialize_provider(row: Provider) -> dict:
if not row:
return None
row = ProviderService.provider_to_dict(row)
return {
"id": row.get("id"),
"name": row.get("name"),
"logo": row.get("logo"),
"type":row.get("type"),
"enabled": row.get("enabled"),
"base_url": row.get("base_url"),
"api_key": row.get("api_key"),
"created_at": jsonable_encoder(row.get("created_at")),
# "name": row[1],
# "logo": row[2],
# "type": row[3],
# "api_key": row[4],
# "base_url": row[5],
# "enabled": row[6],
# "created_at": row[7],
}
@staticmethod
def serialize_provider_safe(row: Provider) -> dict:
if not row:
return None
row = ProviderService.provider_to_dict(row)
return {
"id": row.get("id"),
"name": row.get("name"),
"logo": row.get("logo"),
"type":row.get("type"),
"enabled": row.get("enabled"),
"base_url": row.get("base_url"),
"api_key": ProviderService.mask_key(row.get("api_key")),
"created_at": jsonable_encoder(row.get("created_at")),
# "id": row[0],
# "name": row[1],
# "logo": row[2],
# "type": row[3],
# "api_key": ProviderService.mask_key(row[4]),
# "base_url": row[5],
# "enabled": row[6],
# "created_at": row[7],
}
@staticmethod
def mask_key(key: str) -> str:
if not key or len(key) < 8:
return '*' * len(key)
return key[:4] + '*' * (len(key) - 8) + key[-4:]
@staticmethod
def add_provider( name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1):
try:
# 内置供应商type='built-in')只能由 seed 流程写入API 创建一律落到 'custom'
# 否则历史上出现过批量伪内置脏数据
if type_ != 'custom':
type_ = 'custom'
existing = get_provider_by_name(name)
if existing is not None:
raise ValueError(f'供应商名称已存在: {name}')
id = uuid().lower()
logo = 'custom'
return insert_provider(id, name, api_key, base_url, logo, type_, enabled)
except Exception as e:
print('创建模式失败',e)
raise
@staticmethod
def provider_to_dict(p: Provider):
return {
"id": p.id,
"name": p.name,
"logo": p.logo,
"type": p.type,
"api_key": p.api_key,
"base_url": p.base_url,
"enabled": p.enabled,
"created_at": p.created_at,
}
@staticmethod
def get_all_providers():
rows = get_all_providers()
if rows is None:
return []
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
@staticmethod
def get_all_providers_safe():
rows = get_all_providers()
return [ProviderService.serialize_provider(row) for row in rows] if (rows) else []
@staticmethod
def get_provider_by_name(name: str):
row = get_provider_by_name(name)
return ProviderService.serialize_provider(row)
@staticmethod
def get_provider_by_id(id: str): # 已改为 str 类型
row = get_provider_by_id(id)
return ProviderService.serialize_provider(row)
@staticmethod
def get_provider_by_id_safe(id: str): # 已改为 str 类型
row = get_provider_by_id(id)
return ProviderService.serialize_provider_safe(row)
# all_models.extend(provider['models'])
@staticmethod
def update_provider(id: str, data: dict)->str | None:
try:
# 过滤掉空值
filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'}
# 防御掩码污染:前端展示时 api_key 被 mask_key() 处理过(如 a92f****...2d3a
# 如果用户未重新输入直接保存,带星号的值不应覆盖原 key。
if 'api_key' in filtered_data and '*' in str(filtered_data.get('api_key', '')):
filtered_data.pop('api_key')
print('更新模型供应商',filtered_data)
update_provider(id, **filtered_data)
# 获取更新后的供应商信息
updated_provider = get_provider_by_id(id)
return {
'id': id,
'enabled': updated_provider.enabled,
}
except Exception as e:
print('更新模型供应商失败:',e)
return None
@staticmethod
def delete_provider(id: str):
return delete_provider(id)