From 9b298d309472f92b6a29e5a594c9ea74dc1a70d9 Mon Sep 17 00:00:00 2001 From: JefferyHcool <1063474837@qq.com> Date: Mon, 26 May 2025 23:16:19 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E5=A2=9E=E5=8A=A0=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=AE=A1=E7=90=86=E5=92=8C=E6=B5=8B=E8=AF=95=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增模型删除功能 - 实现模型测试连接功能 - 优化模型选择器组件 - 更新模型相关API和数据库操作 --- .../src/components/Form/modelForm/Form.tsx | 76 +++++++++++++++++-- .../Form/modelForm/ModelSelector.tsx | 4 +- .../src/pages/SettingPage/about.tsx | 2 +- BillNote_frontend/src/services/model.ts | 10 ++- .../src/store/modelStore/index.ts | 16 +++- .../src/store/providerStore/index.ts | 2 + backend/app/exceptions/provider.py | 5 ++ backend/app/routers/provider.py | 26 +++---- backend/app/services/model.py | 48 ++++++++++-- backend/app/services/provider.py | 6 +- 10 files changed, 158 insertions(+), 37 deletions(-) create mode 100644 backend/app/exceptions/provider.py diff --git a/BillNote_frontend/src/components/Form/modelForm/Form.tsx b/BillNote_frontend/src/components/Form/modelForm/Form.tsx index 45a6de1..4c1b7e6 100644 --- a/BillNote_frontend/src/components/Form/modelForm/Form.tsx +++ b/BillNote_frontend/src/components/Form/modelForm/Form.tsx @@ -16,7 +16,7 @@ import { useParams, useNavigate } from 'react-router-dom' import { useProviderStore } from '@/store/providerStore' import { useEffect, useState } from 'react' import toast from 'react-hot-toast' -import { testConnection, fetchModels } from '@/services/model.ts' +import { testConnection, fetchModels, deleteModelById } from '@/services/model.ts' import { Select, SelectContent, @@ -26,6 +26,9 @@ import { } from '@/components/ui/select.tsx' // ⚡新增 fetchModels import { ModelSelector } from '@/components/Form/modelForm/ModelSelector.tsx' import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert.tsx' +import { Tags } from 'lucide-react' +import { Tag } from 'antd' +import { useModelStore } from '@/store/modelStore' // ✅ Provider表单schema const ProviderSchema = z.object({ @@ -52,7 +55,7 @@ interface IModel { root: string } const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => { - const { id } = useParams() + let { id } = useParams() const navigate = useNavigate() const isEditMode = !isCreate @@ -60,12 +63,16 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => { const loadProviderById = useProviderStore(state => state.loadProviderById) const updateProvider = useProviderStore(state => state.updateProvider) const addNewProvider = useProviderStore(state => state.addNewProvider) - const [loading, setLoading] = useState(true) const [testing, setTesting] = useState(false) const [isBuiltIn, setIsBuiltIn] = useState(false) + const loadModelsById= useModelStore(state => state.loadModelsById) const [modelOptions, setModelOptions] = useState([]) // ⚡新增,保存模型列表 + const [models, setModels]= useState([]) const [modelLoading, setModelLoading] = useState(false) + const randomColor = ()=>{ + return '#' + Math.floor(Math.random() * 16777215).toString(16) + } const [search, setSearch] = useState('') const providerForm = useForm({ @@ -91,8 +98,10 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => { }) useEffect(() => { + const load = async () => { if (isEditMode) { + const data = await loadProviderById(id!) providerForm.reset(data) setIsBuiltIn(data.type === 'built-in') @@ -105,11 +114,30 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => { }) setIsBuiltIn(false) } + const models = await loadModelsById(id!) + if(models){ + console.log('🔧 模型列表:', models) + setModels(models) + + } setLoading(false) } load() }, [id]) + const handelDelete=async (modelId)=>{ + if (!window.confirm('确定要删除这个模型吗?')) return + try { + const res = await deleteModelById(modelId) + if (res.data.code === 0) { + toast.success('删除成功') + } else { + toast.error(res.data.msg || '删除失败') + } + } catch (e) { + toast.error('删除异常') + } + } // 测试连通性 const handleTest = async () => { const values = providerForm.getValues() @@ -118,10 +146,13 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => { return } try { + if (!id){ + toast.error('请先保存供应商信息') + return + } setTesting(true) const data = await testConnection({ - api_key: values.apiKey, - base_url: values.baseUrl, + id }) if (data.data.code === 0) { toast.success('测试连通性成功 🎉') @@ -162,18 +193,21 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => { // 保存Provider信息 const onProviderSubmit = async (values: ProviderFormValues) => { if (isEditMode) { - updateProvider({ ...values, id: id! }) + await updateProvider({ ...values, id: id! }) toast.success('更新供应商成功') } else { - addNewProvider({ ...values }) + id = await addNewProvider({ ...values }) + toast.success('新增供应商成功') } + // 刷新页面 + } // 保存Model信息 const onModelSubmit = async (values: ModelFormValues) => { - console.log('🔧 选择的模型:', values.modelName) toast.success(`保存模型: ${values.modelName}`) + await loadModelsById(id!) } if (loading) return
加载中...
@@ -267,6 +301,32 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => { + {/**/} + {/* {modelOptions.map(model => (*/} + {/* */} + +
+ 已启用模型 +
+ { + models && models.map(model => { + return ( + <> + { + handelDelete(model.id) + }} key={model.id} closable color={'blue'}> + {model.model_name} + + + ) + }) + } + +
+ {/**/} + {/**/} {/* {modelOptions.map(model => (*/} {/*
- {filteredModels.map(model => ( - + {filteredModels.map((model, index) => ( + {model.id} ))} diff --git a/BillNote_frontend/src/pages/SettingPage/about.tsx b/BillNote_frontend/src/pages/SettingPage/about.tsx index 7cf036b..bab4f0d 100644 --- a/BillNote_frontend/src/pages/SettingPage/about.tsx +++ b/BillNote_frontend/src/pages/SettingPage/about.tsx @@ -26,7 +26,7 @@ export default function AboutPage() { height={50} className="rounded-lg" /> -

BiliNote v1.5.0

+

BiliNote v1.7.2

AI 视频笔记生成工具 让 AI 为你的视频做笔记 diff --git a/BillNote_frontend/src/services/model.ts b/BillNote_frontend/src/services/model.ts index 4a40f1d..ec91d94 100644 --- a/BillNote_frontend/src/services/model.ts +++ b/BillNote_frontend/src/services/model.ts @@ -18,10 +18,14 @@ export const testConnection = async (data: any) => { return await request.post('/connect_test', data) } -export const fetchModels = async (providerId: any) => { +export const fetchModels = async (providerId: string) => { return await request.get('/model_list/' + providerId) } +export const fetchEnableModelById = async (id: string) => { + return await request.get('/model_enable/' + id) +} + export async function addModel(data: { provider_id: string; model_name: string }) { return request.post('/models', data) } @@ -29,3 +33,7 @@ export async function addModel(data: { provider_id: string; model_name: string } export const fetchEnableModels = async () => { return await request.get('/model_list') } + +export const deleteModelById = async (modelId: number) => { + return await request.get(`/models/delete/${modelId}`) +} \ No newline at end of file diff --git a/BillNote_frontend/src/store/modelStore/index.ts b/BillNote_frontend/src/store/modelStore/index.ts index 0e0b7c8..b747757 100644 --- a/BillNote_frontend/src/store/modelStore/index.ts +++ b/BillNote_frontend/src/store/modelStore/index.ts @@ -1,6 +1,6 @@ import { create } from 'zustand' import { devtools } from 'zustand/middleware' -import { fetchModels, addModel, fetchEnableModels } from '@/services/model.ts' +import { fetchModels, addModel, fetchEnableModels, fetchEnableModelById, deleteModelById } from '@/services/model.ts' interface IModel { id: string @@ -18,8 +18,10 @@ interface ModelStore { selectedModel: string loadModels: (providerId: string) => Promise loadEnabledModels: () => Promise + loadModelsById : (providerId: string) => Promise addNewModel: (providerId: string, modelId: string) => Promise setSelectedModel: (modelId: string) => void + deleteModel: (modelId: number) => Promise clearModels: () => void } @@ -45,6 +47,10 @@ export const useModelStore = create()( console.error('加载模型出错', error) } }, + + deleteModel: async (modelId: number) => { + await deleteModelById( modelId) + }, // 加载模型列表 loadModels: async (providerId: string) => { try { @@ -65,7 +71,13 @@ export const useModelStore = create()( set({ loading: false }) } }, - + loadModelsById: async (providerId: string)=>{ + const models = await fetchEnableModelById(providerId) + if (models.data.code === 0) { + console.log('模型列表加载成功:', models.data) + return models.data.data + } + }, // 新增模型 addNewModel: async (providerId: string, modelId: string) => { try { diff --git a/BillNote_frontend/src/store/providerStore/index.ts b/BillNote_frontend/src/store/providerStore/index.ts index 36819b2..f61465d 100644 --- a/BillNote_frontend/src/store/providerStore/index.ts +++ b/BillNote_frontend/src/store/providerStore/index.ts @@ -66,7 +66,9 @@ export const useProviderStore = create((set, get) => ({ if (res.data.code === 0) { const item = res.data.data console.log('Provider ', item) + await get().fetchProviderList() + return item } } catch (error) { console.error('Error fetching provider:', error) diff --git a/backend/app/exceptions/provider.py b/backend/app/exceptions/provider.py new file mode 100644 index 0000000..a03296e --- /dev/null +++ b/backend/app/exceptions/provider.py @@ -0,0 +1,5 @@ +# exceptions.py +class ConnectionTestError(Exception): + def __init__(self, message: str): + super().__init__(message) + self.message = message \ No newline at end of file diff --git a/backend/app/routers/provider.py b/backend/app/routers/provider.py index e934ba3..9710082 100644 --- a/backend/app/routers/provider.py +++ b/backend/app/routers/provider.py @@ -2,6 +2,7 @@ from typing import Optional from fastapi import APIRouter from pydantic import BaseModel +from app.exceptions.provider import ConnectionTestError from app.models.model_config import ModelConfig from app.services.model import ModelService from app.utils.response import ResponseWrapper as R @@ -18,9 +19,7 @@ class ProviderRequest(BaseModel): type: str class TestRequest(BaseModel): - - api_key: str - base_url:str + id: str class ProviderUpdateRequest(BaseModel): id: str name: Optional[str] = None @@ -33,14 +32,14 @@ class ProviderUpdateRequest(BaseModel): @router.post("/add_provider") def add_provider(data: ProviderRequest): try: - ProviderService.add_provider( + res = ProviderService.add_provider( name=data.name, api_key=data.api_key, base_url=data.base_url, logo=data.logo, type_=data.type ) - return R.success(msg='添加模型供应商成功') + return R.success(msg='添加模型供应商成功',data=res) except Exception as e: return R.error(msg=e) @@ -78,23 +77,20 @@ def update_provider(data: ProviderUpdateRequest): ): return R.error(msg='请至少填写一个参数') - ProviderService.update_provider( + provider_id =ProviderService.update_provider( id=data.id, data=dict(data) ) - return R.success(msg='更新模型供应商成功') + return R.success(msg='更新模型供应商成功',data={'id': provider_id}) except Exception as e: print(e) - return R.error(msg=e) + return R.error(msg=str(e)) @router.post('/connect_test') -def gpt_connect_test(data:TestRequest): +def gpt_connect_test(data: TestRequest): try: - - res= ModelService().connect_test(data.api_key,data.base_url) - if not res: - return R.error(msg='连接失败') + ModelService().connect_test(data.id) return R.success(msg='连接成功') except Exception as e: - print(e) - return R.error(msg=e) \ No newline at end of file + print("捕获到异常类型:", type(e)) + return R.error(msg=str(e)) \ No newline at end of file diff --git a/backend/app/services/model.py b/backend/app/services/model.py index 61e4112..0c60d10 100644 --- a/backend/app/services/model.py +++ b/backend/app/services/model.py @@ -1,5 +1,6 @@ -from app.db.model_dao import insert_model, get_all_models +from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model from app.db.provider_dao import get_enabled_providers +from app.exceptions.provider import ConnectionTestError from app.gpt.gpt_factory import GPTFactory from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider from app.models.model_config import ModelConfig @@ -70,6 +71,13 @@ class ModelService: }) return formatted @staticmethod + def get_enabled_models_by_provider( provider_id: str|int,): + from app.db.model_dao import get_models_by_provider + + all_models = get_models_by_provider(provider_id) + enabled_models = all_models + return enabled_models + @staticmethod def get_all_models_by_id(provider_id: str, verbose: bool = False): try: provider = ProviderService.get_provider_by_id(provider_id) @@ -86,13 +94,35 @@ class ModelService: print(f"[{provider_id}] 获取模型失败: {e}") return [] @staticmethod - def connect_test(api_key: str, base_url: str) -> bool: + def connect_test(id: str) -> bool: try: - return OpenAICompatibleProvider.test_connection(api_key=api_key, base_url=base_url) - except Exception as e: - print(f"连接测试失败:{e}") - return False + provider = ProviderService.get_provider_by_id(id) + if provider: + if not provider.get('api_key'): + raise ConnectionTestError(f"供应商信息未找到,请先保存重试") + result = OpenAICompatibleProvider.test_connection( + api_key=provider.get('api_key'), + base_url=provider.get('base_url') + ) + if result: + return True + else: + raise ConnectionTestError("请检查API Key 和 API 地址是否正确") + + raise ConnectionTestError("供应商信息未找到,请先保存重试") + except Exception as e: + # 抛出业务异常,交由 Controller 处理 + raise ConnectionTestError(f"{str(e)}") from e + + @staticmethod + def delete_model_by_id( model_id: int) -> bool: + try: + delete_model(model_id) + return True + except Exception as e: + print(f"[{model_id}] : {e}") + return False @staticmethod def add_new_model(provider_id: int, model_name: str) -> bool: try: @@ -102,6 +132,12 @@ class ModelService: print(f"供应商ID {provider_id} 不存在,无法添加模型") return False + # 查询是否已存在同名模型 + existing = get_model_by_provider_and_name(provider_id, model_name) + if existing: + print(f"模型 {model_name} 已存在于供应商ID {provider_id} 下,跳过插入") + return False + # 插入模型 insert_model(provider_id=provider_id, model_name=model_name) print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}") diff --git a/backend/app/services/provider.py b/backend/app/services/provider.py index 48d0909..d953fdd 100644 --- a/backend/app/services/provider.py +++ b/backend/app/services/provider.py @@ -82,15 +82,17 @@ class ProviderService: # all_models.extend(provider['models']) @staticmethod - def update_provider(id: str, data: dict): + def update_provider(id: str, data: dict)->str | None: try: # 过滤掉空值 filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'} print('更新模型供应商',filtered_data) - return update_provider(id, **filtered_data) + update_provider(id, **filtered_data) + return id except Exception as e: print('更新模型供应商失败:',e) + return None @staticmethod def delete_provider(id: str):