mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-26 02:01:38 +08:00
feat(model): 增加模型管理和测试功能
- 新增模型删除功能 - 实现模型测试连接功能 - 优化模型选择器组件 - 更新模型相关API和数据库操作
This commit is contained in:
@@ -16,7 +16,7 @@ import { useParams, useNavigate } from 'react-router-dom'
|
||||
import { useProviderStore } from '@/store/providerStore'
|
||||
import { useEffect, useState } from 'react'
|
||||
import toast from 'react-hot-toast'
|
||||
import { testConnection, fetchModels } from '@/services/model.ts'
|
||||
import { testConnection, fetchModels, deleteModelById } from '@/services/model.ts'
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
@@ -26,6 +26,9 @@ import {
|
||||
} from '@/components/ui/select.tsx' // ⚡新增 fetchModels
|
||||
import { ModelSelector } from '@/components/Form/modelForm/ModelSelector.tsx'
|
||||
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert.tsx'
|
||||
import { Tags } from 'lucide-react'
|
||||
import { Tag } from 'antd'
|
||||
import { useModelStore } from '@/store/modelStore'
|
||||
|
||||
// ✅ Provider表单schema
|
||||
const ProviderSchema = z.object({
|
||||
@@ -52,7 +55,7 @@ interface IModel {
|
||||
root: string
|
||||
}
|
||||
const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
const { id } = useParams()
|
||||
let { id } = useParams()
|
||||
const navigate = useNavigate()
|
||||
const isEditMode = !isCreate
|
||||
|
||||
@@ -60,12 +63,16 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
const loadProviderById = useProviderStore(state => state.loadProviderById)
|
||||
const updateProvider = useProviderStore(state => state.updateProvider)
|
||||
const addNewProvider = useProviderStore(state => state.addNewProvider)
|
||||
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [testing, setTesting] = useState(false)
|
||||
const [isBuiltIn, setIsBuiltIn] = useState(false)
|
||||
const loadModelsById= useModelStore(state => state.loadModelsById)
|
||||
const [modelOptions, setModelOptions] = useState<IModel[]>([]) // ⚡新增,保存模型列表
|
||||
const [models, setModels]= useState([])
|
||||
const [modelLoading, setModelLoading] = useState(false)
|
||||
const randomColor = ()=>{
|
||||
return '#' + Math.floor(Math.random() * 16777215).toString(16)
|
||||
}
|
||||
|
||||
const [search, setSearch] = useState('')
|
||||
const providerForm = useForm<ProviderFormValues>({
|
||||
@@ -91,8 +98,10 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
const load = async () => {
|
||||
if (isEditMode) {
|
||||
|
||||
const data = await loadProviderById(id!)
|
||||
providerForm.reset(data)
|
||||
setIsBuiltIn(data.type === 'built-in')
|
||||
@@ -105,11 +114,30 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
})
|
||||
setIsBuiltIn(false)
|
||||
}
|
||||
const models = await loadModelsById(id!)
|
||||
if(models){
|
||||
console.log('🔧 模型列表:', models)
|
||||
setModels(models)
|
||||
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
load()
|
||||
}, [id])
|
||||
const handelDelete=async (modelId)=>{
|
||||
if (!window.confirm('确定要删除这个模型吗?')) return
|
||||
|
||||
try {
|
||||
const res = await deleteModelById(modelId)
|
||||
if (res.data.code === 0) {
|
||||
toast.success('删除成功')
|
||||
} else {
|
||||
toast.error(res.data.msg || '删除失败')
|
||||
}
|
||||
} catch (e) {
|
||||
toast.error('删除异常')
|
||||
}
|
||||
}
|
||||
// 测试连通性
|
||||
const handleTest = async () => {
|
||||
const values = providerForm.getValues()
|
||||
@@ -118,10 +146,13 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
return
|
||||
}
|
||||
try {
|
||||
if (!id){
|
||||
toast.error('请先保存供应商信息')
|
||||
return
|
||||
}
|
||||
setTesting(true)
|
||||
const data = await testConnection({
|
||||
api_key: values.apiKey,
|
||||
base_url: values.baseUrl,
|
||||
id
|
||||
})
|
||||
if (data.data.code === 0) {
|
||||
toast.success('测试连通性成功 🎉')
|
||||
@@ -162,18 +193,21 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
// 保存Provider信息
|
||||
const onProviderSubmit = async (values: ProviderFormValues) => {
|
||||
if (isEditMode) {
|
||||
updateProvider({ ...values, id: id! })
|
||||
await updateProvider({ ...values, id: id! })
|
||||
toast.success('更新供应商成功')
|
||||
} else {
|
||||
addNewProvider({ ...values })
|
||||
id = await addNewProvider({ ...values })
|
||||
|
||||
toast.success('新增供应商成功')
|
||||
}
|
||||
// 刷新页面
|
||||
|
||||
}
|
||||
|
||||
// 保存Model信息
|
||||
const onModelSubmit = async (values: ModelFormValues) => {
|
||||
console.log('🔧 选择的模型:', values.modelName)
|
||||
toast.success(`保存模型: ${values.modelName}`)
|
||||
await loadModelsById(id!)
|
||||
}
|
||||
|
||||
if (loading) return <div className="p-4">加载中...</div>
|
||||
@@ -267,6 +301,32 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
</div>
|
||||
<ModelSelector providerId={id!} />
|
||||
|
||||
{/*<datalist id="model-options">*/}
|
||||
{/* {modelOptions.map(model => (*/}
|
||||
{/* <option key={model.id + '1'} value={model.id} />*/}
|
||||
{/* ))}*/}
|
||||
{/*</datalist>*/}
|
||||
</div>
|
||||
<div className="flex flex-col gap-2">
|
||||
<span className="font-bold">已启用模型</span>
|
||||
<div className={'flex flex-wrap gap-2 rounded p-2.5'}>
|
||||
{
|
||||
models && models.map(model => {
|
||||
return (
|
||||
<>
|
||||
<Tag onClose={()=>{
|
||||
handelDelete(model.id)
|
||||
}} key={model.id} closable color={'blue'}>
|
||||
{model.model_name}
|
||||
</Tag></>
|
||||
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
</div>
|
||||
{/*<ModelSelector providerId={id!} />*/}
|
||||
|
||||
{/*<datalist id="model-options">*/}
|
||||
{/* {modelOptions.map(model => (*/}
|
||||
{/* <option key={model.id + '1'} value={model.id} />*/}
|
||||
|
||||
@@ -76,8 +76,8 @@ export function ModelSelector({ providerId }: ModelSelectorProps) {
|
||||
className="h-8"
|
||||
/>
|
||||
</div>
|
||||
{filteredModels.map(model => (
|
||||
<SelectItem key={model.id} value={model.id}>
|
||||
{filteredModels.map((model, index) => (
|
||||
<SelectItem key={`${model.id}-${index}`} value={model.id}>
|
||||
{model.id}
|
||||
</SelectItem>
|
||||
))}
|
||||
|
||||
@@ -26,7 +26,7 @@ export default function AboutPage() {
|
||||
height={50}
|
||||
className="rounded-lg"
|
||||
/>
|
||||
<h1 className="text-4xl font-bold">BiliNote v1.5.0</h1>
|
||||
<h1 className="text-4xl font-bold">BiliNote v1.7.2</h1>
|
||||
</div>
|
||||
<p className="text-muted-foreground mb-6 text-xl italic">
|
||||
AI 视频笔记生成工具 让 AI 为你的视频做笔记
|
||||
|
||||
@@ -18,10 +18,14 @@ export const testConnection = async (data: any) => {
|
||||
return await request.post('/connect_test', data)
|
||||
}
|
||||
|
||||
export const fetchModels = async (providerId: any) => {
|
||||
export const fetchModels = async (providerId: string) => {
|
||||
return await request.get('/model_list/' + providerId)
|
||||
}
|
||||
|
||||
export const fetchEnableModelById = async (id: string) => {
|
||||
return await request.get('/model_enable/' + id)
|
||||
}
|
||||
|
||||
export async function addModel(data: { provider_id: string; model_name: string }) {
|
||||
return request.post('/models', data)
|
||||
}
|
||||
@@ -29,3 +33,7 @@ export async function addModel(data: { provider_id: string; model_name: string }
|
||||
export const fetchEnableModels = async () => {
|
||||
return await request.get('/model_list')
|
||||
}
|
||||
|
||||
export const deleteModelById = async (modelId: number) => {
|
||||
return await request.get(`/models/delete/${modelId}`)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { create } from 'zustand'
|
||||
import { devtools } from 'zustand/middleware'
|
||||
import { fetchModels, addModel, fetchEnableModels } from '@/services/model.ts'
|
||||
import { fetchModels, addModel, fetchEnableModels, fetchEnableModelById, deleteModelById } from '@/services/model.ts'
|
||||
|
||||
interface IModel {
|
||||
id: string
|
||||
@@ -18,8 +18,10 @@ interface ModelStore {
|
||||
selectedModel: string
|
||||
loadModels: (providerId: string) => Promise<void>
|
||||
loadEnabledModels: () => Promise<void>
|
||||
loadModelsById : (providerId: string) => Promise<void>
|
||||
addNewModel: (providerId: string, modelId: string) => Promise<void>
|
||||
setSelectedModel: (modelId: string) => void
|
||||
deleteModel: (modelId: number) => Promise<void>
|
||||
clearModels: () => void
|
||||
}
|
||||
|
||||
@@ -45,6 +47,10 @@ export const useModelStore = create<ModelStore>()(
|
||||
console.error('加载模型出错', error)
|
||||
}
|
||||
},
|
||||
|
||||
deleteModel: async (modelId: number) => {
|
||||
await deleteModelById( modelId)
|
||||
},
|
||||
// 加载模型列表
|
||||
loadModels: async (providerId: string) => {
|
||||
try {
|
||||
@@ -65,7 +71,13 @@ export const useModelStore = create<ModelStore>()(
|
||||
set({ loading: false })
|
||||
}
|
||||
},
|
||||
|
||||
loadModelsById: async (providerId: string)=>{
|
||||
const models = await fetchEnableModelById(providerId)
|
||||
if (models.data.code === 0) {
|
||||
console.log('模型列表加载成功:', models.data)
|
||||
return models.data.data
|
||||
}
|
||||
},
|
||||
// 新增模型
|
||||
addNewModel: async (providerId: string, modelId: string) => {
|
||||
try {
|
||||
|
||||
@@ -66,7 +66,9 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
if (res.data.code === 0) {
|
||||
const item = res.data.data
|
||||
console.log('Provider ', item)
|
||||
|
||||
await get().fetchProviderList()
|
||||
return item
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching provider:', error)
|
||||
|
||||
5
backend/app/exceptions/provider.py
Normal file
5
backend/app/exceptions/provider.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# exceptions.py
|
||||
class ConnectionTestError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
@@ -2,6 +2,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.exceptions.provider import ConnectionTestError
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.model import ModelService
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
@@ -18,9 +19,7 @@ class ProviderRequest(BaseModel):
|
||||
type: str
|
||||
|
||||
class TestRequest(BaseModel):
|
||||
|
||||
api_key: str
|
||||
base_url:str
|
||||
id: str
|
||||
class ProviderUpdateRequest(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
@@ -33,14 +32,14 @@ class ProviderUpdateRequest(BaseModel):
|
||||
@router.post("/add_provider")
|
||||
def add_provider(data: ProviderRequest):
|
||||
try:
|
||||
ProviderService.add_provider(
|
||||
res = ProviderService.add_provider(
|
||||
name=data.name,
|
||||
api_key=data.api_key,
|
||||
base_url=data.base_url,
|
||||
logo=data.logo,
|
||||
type_=data.type
|
||||
)
|
||||
return R.success(msg='添加模型供应商成功')
|
||||
return R.success(msg='添加模型供应商成功',data=res)
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
|
||||
@@ -78,23 +77,20 @@ def update_provider(data: ProviderUpdateRequest):
|
||||
):
|
||||
return R.error(msg='请至少填写一个参数')
|
||||
|
||||
ProviderService.update_provider(
|
||||
provider_id =ProviderService.update_provider(
|
||||
id=data.id,
|
||||
data=dict(data)
|
||||
)
|
||||
return R.success(msg='更新模型供应商成功')
|
||||
return R.success(msg='更新模型供应商成功',data={'id': provider_id})
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return R.error(msg=e)
|
||||
return R.error(msg=str(e))
|
||||
|
||||
@router.post('/connect_test')
|
||||
def gpt_connect_test(data:TestRequest):
|
||||
def gpt_connect_test(data: TestRequest):
|
||||
try:
|
||||
|
||||
res= ModelService().connect_test(data.api_key,data.base_url)
|
||||
if not res:
|
||||
return R.error(msg='连接失败')
|
||||
ModelService().connect_test(data.id)
|
||||
return R.success(msg='连接成功')
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return R.error(msg=e)
|
||||
print("捕获到异常类型:", type(e))
|
||||
return R.error(msg=str(e))
|
||||
@@ -1,5 +1,6 @@
|
||||
from app.db.model_dao import insert_model, get_all_models
|
||||
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
|
||||
from app.db.provider_dao import get_enabled_providers
|
||||
from app.exceptions.provider import ConnectionTestError
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||
from app.models.model_config import ModelConfig
|
||||
@@ -70,6 +71,13 @@ class ModelService:
|
||||
})
|
||||
return formatted
|
||||
@staticmethod
|
||||
def get_enabled_models_by_provider( provider_id: str|int,):
|
||||
from app.db.model_dao import get_models_by_provider
|
||||
|
||||
all_models = get_models_by_provider(provider_id)
|
||||
enabled_models = all_models
|
||||
return enabled_models
|
||||
@staticmethod
|
||||
def get_all_models_by_id(provider_id: str, verbose: bool = False):
|
||||
try:
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
@@ -86,13 +94,35 @@ class ModelService:
|
||||
print(f"[{provider_id}] 获取模型失败: {e}")
|
||||
return []
|
||||
@staticmethod
|
||||
def connect_test(api_key: str, base_url: str) -> bool:
|
||||
def connect_test(id: str) -> bool:
|
||||
try:
|
||||
return OpenAICompatibleProvider.test_connection(api_key=api_key, base_url=base_url)
|
||||
except Exception as e:
|
||||
print(f"连接测试失败:{e}")
|
||||
return False
|
||||
provider = ProviderService.get_provider_by_id(id)
|
||||
|
||||
if provider:
|
||||
if not provider.get('api_key'):
|
||||
raise ConnectionTestError(f"供应商信息未找到,请先保存重试")
|
||||
result = OpenAICompatibleProvider.test_connection(
|
||||
api_key=provider.get('api_key'),
|
||||
base_url=provider.get('base_url')
|
||||
)
|
||||
if result:
|
||||
return True
|
||||
else:
|
||||
raise ConnectionTestError("请检查API Key 和 API 地址是否正确")
|
||||
|
||||
raise ConnectionTestError("供应商信息未找到,请先保存重试")
|
||||
except Exception as e:
|
||||
# 抛出业务异常,交由 Controller 处理
|
||||
raise ConnectionTestError(f"{str(e)}") from e
|
||||
|
||||
@staticmethod
|
||||
def delete_model_by_id( model_id: int) -> bool:
|
||||
try:
|
||||
delete_model(model_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[{model_id}] <UNK>: {e}")
|
||||
return False
|
||||
@staticmethod
|
||||
def add_new_model(provider_id: int, model_name: str) -> bool:
|
||||
try:
|
||||
@@ -102,6 +132,12 @@ class ModelService:
|
||||
print(f"供应商ID {provider_id} 不存在,无法添加模型")
|
||||
return False
|
||||
|
||||
# 查询是否已存在同名模型
|
||||
existing = get_model_by_provider_and_name(provider_id, model_name)
|
||||
if existing:
|
||||
print(f"模型 {model_name} 已存在于供应商ID {provider_id} 下,跳过插入")
|
||||
return False
|
||||
|
||||
# 插入模型
|
||||
insert_model(provider_id=provider_id, model_name=model_name)
|
||||
print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}")
|
||||
|
||||
@@ -82,15 +82,17 @@ class ProviderService:
|
||||
# all_models.extend(provider['models'])
|
||||
|
||||
@staticmethod
|
||||
def update_provider(id: str, data: dict):
|
||||
def update_provider(id: str, data: dict)->str | None:
|
||||
try:
|
||||
# 过滤掉空值
|
||||
filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'}
|
||||
print('更新模型供应商',filtered_data)
|
||||
return update_provider(id, **filtered_data)
|
||||
update_provider(id, **filtered_data)
|
||||
return id
|
||||
|
||||
except Exception as e:
|
||||
print('更新模型供应商失败:',e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def delete_provider(id: str):
|
||||
|
||||
Reference in New Issue
Block a user