feat(model): 增加模型管理和测试功能

- 新增模型删除功能
- 实现模型测试连接功能
- 优化模型选择器组件
- 更新模型相关API和数据库操作
This commit is contained in:
JefferyHcool
2025-05-26 23:16:19 +08:00
parent ee9f6ed80c
commit 9b298d3094
10 changed files with 158 additions and 37 deletions

View File

@@ -0,0 +1,5 @@
# exceptions.py
class ConnectionTestError(Exception):
def __init__(self, message: str):
super().__init__(message)
self.message = message

View File

@@ -2,6 +2,7 @@ from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from app.exceptions.provider import ConnectionTestError
from app.models.model_config import ModelConfig
from app.services.model import ModelService
from app.utils.response import ResponseWrapper as R
@@ -18,9 +19,7 @@ class ProviderRequest(BaseModel):
type: str
class TestRequest(BaseModel):
api_key: str
base_url:str
id: str
class ProviderUpdateRequest(BaseModel):
id: str
name: Optional[str] = None
@@ -33,14 +32,14 @@ class ProviderUpdateRequest(BaseModel):
@router.post("/add_provider")
def add_provider(data: ProviderRequest):
try:
ProviderService.add_provider(
res = ProviderService.add_provider(
name=data.name,
api_key=data.api_key,
base_url=data.base_url,
logo=data.logo,
type_=data.type
)
return R.success(msg='添加模型供应商成功')
return R.success(msg='添加模型供应商成功',data=res)
except Exception as e:
return R.error(msg=e)
@@ -78,23 +77,20 @@ def update_provider(data: ProviderUpdateRequest):
):
return R.error(msg='请至少填写一个参数')
ProviderService.update_provider(
provider_id =ProviderService.update_provider(
id=data.id,
data=dict(data)
)
return R.success(msg='更新模型供应商成功')
return R.success(msg='更新模型供应商成功',data={'id': provider_id})
except Exception as e:
print(e)
return R.error(msg=e)
return R.error(msg=str(e))
@router.post('/connect_test')
def gpt_connect_test(data:TestRequest):
def gpt_connect_test(data: TestRequest):
try:
res= ModelService().connect_test(data.api_key,data.base_url)
if not res:
return R.error(msg='连接失败')
ModelService().connect_test(data.id)
return R.success(msg='连接成功')
except Exception as e:
print(e)
return R.error(msg=e)
print("捕获到异常类型:", type(e))
return R.error(msg=str(e))

View File

@@ -1,5 +1,6 @@
from app.db.model_dao import insert_model, get_all_models
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
from app.db.provider_dao import get_enabled_providers
from app.exceptions.provider import ConnectionTestError
from app.gpt.gpt_factory import GPTFactory
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.models.model_config import ModelConfig
@@ -70,6 +71,13 @@ class ModelService:
})
return formatted
@staticmethod
def get_enabled_models_by_provider( provider_id: str|int,):
from app.db.model_dao import get_models_by_provider
all_models = get_models_by_provider(provider_id)
enabled_models = all_models
return enabled_models
@staticmethod
def get_all_models_by_id(provider_id: str, verbose: bool = False):
try:
provider = ProviderService.get_provider_by_id(provider_id)
@@ -86,13 +94,35 @@ class ModelService:
print(f"[{provider_id}] 获取模型失败: {e}")
return []
@staticmethod
def connect_test(api_key: str, base_url: str) -> bool:
def connect_test(id: str) -> bool:
try:
return OpenAICompatibleProvider.test_connection(api_key=api_key, base_url=base_url)
except Exception as e:
print(f"连接测试失败:{e}")
return False
provider = ProviderService.get_provider_by_id(id)
if provider:
if not provider.get('api_key'):
raise ConnectionTestError(f"供应商信息未找到,请先保存重试")
result = OpenAICompatibleProvider.test_connection(
api_key=provider.get('api_key'),
base_url=provider.get('base_url')
)
if result:
return True
else:
raise ConnectionTestError("请检查API Key 和 API 地址是否正确")
raise ConnectionTestError("供应商信息未找到,请先保存重试")
except Exception as e:
# 抛出业务异常,交由 Controller 处理
raise ConnectionTestError(f"{str(e)}") from e
@staticmethod
def delete_model_by_id( model_id: int) -> bool:
try:
delete_model(model_id)
return True
except Exception as e:
print(f"[{model_id}] <UNK>: {e}")
return False
@staticmethod
def add_new_model(provider_id: int, model_name: str) -> bool:
try:
@@ -102,6 +132,12 @@ class ModelService:
print(f"供应商ID {provider_id} 不存在,无法添加模型")
return False
# 查询是否已存在同名模型
existing = get_model_by_provider_and_name(provider_id, model_name)
if existing:
print(f"模型 {model_name} 已存在于供应商ID {provider_id} 下,跳过插入")
return False
# 插入模型
insert_model(provider_id=provider_id, model_name=model_name)
print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}")

View File

@@ -82,15 +82,17 @@ class ProviderService:
# all_models.extend(provider['models'])
@staticmethod
def update_provider(id: str, data: dict):
def update_provider(id: str, data: dict)->str | None:
try:
# 过滤掉空值
filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'}
print('更新模型供应商',filtered_data)
return update_provider(id, **filtered_data)
update_provider(id, **filtered_data)
return id
except Exception as e:
print('更新模型供应商失败:',e)
return None
@staticmethod
def delete_provider(id: str):