diff --git a/BillNote_frontend/src/components/Form/modelForm/components/providerCard.tsx b/BillNote_frontend/src/components/Form/modelForm/components/providerCard.tsx index 55d2f24..4092af8 100644 --- a/BillNote_frontend/src/components/Form/modelForm/components/providerCard.tsx +++ b/BillNote_frontend/src/components/Form/modelForm/components/providerCard.tsx @@ -4,47 +4,51 @@ import styles from './index.module.css' import { useNavigate, useParams } from 'react-router-dom' import AILogo from '@/components/Form/modelForm/Icons' import { useProviderStore } from '@/store/providerStore' + export interface IProviderCardProps { id: string providerName: string Icon: string enable: number } + const ProviderCard: FC = ({ providerName, Icon, id, - enable, }: IProviderCardProps) => { const navigate = useNavigate() const updateProvider = useProviderStore(state => state.updateProvider) - const handleClick = () => { - navigate(`/settings/model/${id}`) - } - const handleEnable = () => { - console.log('enable', enable) + const enabled = useProviderStore(state => state.provider.find(p => p.id === id)?.enabled) + + const isChecked = enabled === 1 + + const handleToggle = (checked: boolean) => { + const allProviders = useProviderStore.getState().provider + const provider = allProviders.find(p => p.id === id) + if (!provider) return updateProvider({ - id, - enabled: enable == 1 ? 0 : 1, + ...provider, + enabled: checked ? 1 : 0, }) } - const rawId = useParams() - console.log('rawId', rawId) + // @ts-ignore const { id: currentId } = useParams() const isActive = currentId === id + return (
{ - handleClick() - }} className={ styles.card + ' flex h-14 items-center justify-between rounded border border-[#f3f3f3] p-2' + (isActive ? ' bg-[#F0F0F0] font-semibold text-blue-600' : '') } > -
+
navigate(`/settings/model/${id}`)} + >
@@ -53,11 +57,8 @@ const ProviderCard: FC = ({
{ - e.preventDefault() - handleEnable() - }} - checked={enable == 1} + checked={isChecked} + onCheckedChange={handleToggle} />
diff --git a/BillNote_frontend/src/store/providerStore/index.ts b/BillNote_frontend/src/store/providerStore/index.ts index 14044c7..d94e042 100644 --- a/BillNote_frontend/src/store/providerStore/index.ts +++ b/BillNote_frontend/src/store/providerStore/index.ts @@ -75,19 +75,19 @@ export const useProviderStore = create((set, get) => ({ getProviderById: id => get().provider.find(p => p.id === id), updateProvider: async (provider: IProvider) => { try { + const existing = get().provider.find(p => p.id === provider.id) + const merged = { ...existing, ...provider } + const data = { - ...provider, - api_key: provider.apiKey, - base_url: provider.baseUrl, - } - const res = await updateProviderById(data) - if (res.data.code === 0) { - const item = res.data.data - console.log('Provider ', item) - await get().fetchProviderList() + ...merged, + api_key: merged.apiKey, + base_url: merged.baseUrl, } + // 拦截器已解包:成功时直接返回 data 部分 + await updateProviderById(data) + await get().fetchProviderList() } catch (error) { - console.error('Error fetching provider:', error) + console.error('Error updating provider:', error) } }, getProviderList: () => get().provider, diff --git a/backend/app/db/model_dao.py b/backend/app/db/model_dao.py index ca8e459..1111e67 100644 --- a/backend/app/db/model_dao.py +++ b/backend/app/db/model_dao.py @@ -1,5 +1,6 @@ from app.db.engine import get_db from app.db.models.models import Model +from app.db.models.providers import Provider def get_model_by_provider_and_name(provider_id: int, model_name: str): @@ -58,7 +59,8 @@ def delete_model(model_id: int): def get_all_models(): db = next(get_db()) try: - models = db.query(Model).all() + # 只查询启用状态供应商的模型 + models = db.query(Model).join(Provider, Model.provider_id == Provider.id).filter(Provider.enabled == 1).all() return [ {"id": m.id, "provider_id": m.provider_id, "model_name": m.model_name} for m in models diff --git a/backend/app/routers/provider.py b/backend/app/routers/provider.py index 942a362..0b3215a 100644 --- a/backend/app/routers/provider.py +++ b/backend/app/routers/provider.py @@ -77,11 +77,14 @@ def update_provider(data: ProviderUpdateRequest): ): return R.error(msg='请至少填写一个参数') - provider_id =ProviderService.update_provider( + updated_provider =ProviderService.update_provider( id=data.id, data=dict(data) ) - return R.success(msg='更新模型供应商成功',data={'id': provider_id}) + if updated_provider: + return R.success(msg='更新模型供应商成功', data=updated_provider) + else: + return R.error(msg='更新模型供应商失败') except Exception as e: print(e) return R.error(msg=str(e)) diff --git a/backend/app/services/provider.py b/backend/app/services/provider.py index ca0a780..b6cef62 100644 --- a/backend/app/services/provider.py +++ b/backend/app/services/provider.py @@ -123,7 +123,12 @@ class ProviderService: filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'} print('更新模型供应商',filtered_data) update_provider(id, **filtered_data) - return id + # 获取更新后的供应商信息 + updated_provider = get_provider_by_id(id) + return { + 'id': id, + 'enabled': updated_provider.enabled, + } except Exception as e: print('更新模型供应商失败:',e)