feat: Validate model type and model name compatibility

This commit is contained in:
lilong.129
2025-05-26 09:40:28 +08:00
parent 4e74247cab
commit f20fdd51bc
3 changed files with 26 additions and 2 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2505260928
v5.0.0-beta-2505260940

View File

@@ -25,7 +25,7 @@ import (
// NewChat creates a new chat session
func (h *MCPHost) NewChat(ctx context.Context) (*Chat, error) {
// Get model config from environment variables
modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeUITARS)
modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeDoubaoVL)
if err != nil {
return nil, err
}

View File

@@ -2,7 +2,9 @@ package ai
import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
@@ -95,6 +97,11 @@ func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
"env %s missed", EnvModelName)
}
// Validate model type and model name compatibility
if err := validateModelType(modelType, modelName); err != nil {
return nil, err
}
// https://www.volcengine.com/docs/82379/1536429
temperature := float32(0)
topP := float32(0.7)
@@ -120,6 +127,23 @@ func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
}, nil
}
func validateModelType(modelType option.LLMServiceType, modelName string) error {
switch modelType {
case option.LLMServiceTypeUITARS:
if !strings.Contains(modelName, "ui-tars") {
return fmt.Errorf("model name %s is not supported for %s", modelName, modelType)
}
return nil
case option.LLMServiceTypeDoubaoVL:
if !strings.Contains(modelName, "doubao") || !strings.Contains(modelName, "vision") {
return fmt.Errorf("model name %s is not supported", modelName)
}
return nil
}
return fmt.Errorf("model type %s is not supported", modelType)
}
// maskAPIKey masks the API key
func maskAPIKey(key string) string {
if len(key) <= 8 {