From f20fdd51bcba6ca926be4ad420ba0979d27b24ec Mon Sep 17 00:00:00 2001 From: "lilong.129" Date: Mon, 26 May 2025 09:40:28 +0800 Subject: [PATCH] feat: Validate model type and model name compatibility --- internal/version/VERSION | 2 +- mcphost/chat.go | 2 +- uixt/ai/ai.go | 24 ++++++++++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/internal/version/VERSION b/internal/version/VERSION index 9a5d9eb6..a2398a92 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2505260928 +v5.0.0-beta-2505260940 diff --git a/mcphost/chat.go b/mcphost/chat.go index 2dac9da2..60e45c59 100644 --- a/mcphost/chat.go +++ b/mcphost/chat.go @@ -25,7 +25,7 @@ import ( // NewChat creates a new chat session func (h *MCPHost) NewChat(ctx context.Context) (*Chat, error) { // Get model config from environment variables - modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeUITARS) + modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeDoubaoVL) if err != nil { return nil, err } diff --git a/uixt/ai/ai.go b/uixt/ai/ai.go index fa1f6a8e..428490dc 100644 --- a/uixt/ai/ai.go +++ b/uixt/ai/ai.go @@ -2,7 +2,9 @@ package ai import ( "context" + "fmt" "os" + "strings" "time" "github.com/cloudwego/eino-ext/components/model/openai" @@ -95,6 +97,11 @@ func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) { "env %s missed", EnvModelName) } + // Validate model type and model name compatibility + if err := validateModelType(modelType, modelName); err != nil { + return nil, err + } + // https://www.volcengine.com/docs/82379/1536429 temperature := float32(0) topP := float32(0.7) @@ -120,6 +127,23 @@ func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) { }, nil } +func validateModelType(modelType option.LLMServiceType, modelName string) error { + switch modelType { + case option.LLMServiceTypeUITARS: + if !strings.Contains(modelName, "ui-tars") { + return fmt.Errorf("model name %s is not supported for %s", modelName, modelType) + } + return nil + case option.LLMServiceTypeDoubaoVL: + if !strings.Contains(modelName, "doubao") || !strings.Contains(modelName, "vision") { + return fmt.Errorf("model name %s is not supported", modelName) + } + return nil + } + + return fmt.Errorf("model type %s is not supported", modelType) +} + // maskAPIKey masks the API key func maskAPIKey(key string) string { if len(key) <= 8 {