Files
httprunner/uixt/ai/ai.go
2025-06-05 18:09:25 +08:00

155 lines
4.0 KiB
Go

package ai
import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
// ILLMService defines the LLM service interface, covering both planning and
// assertion.
type ILLMService interface {
// Call runs planning with the given options and returns the planning result.
Call(ctx context.Context, opts *PlanningOptions) (*PlanningResult, error)
// Assert runs an assertion with the given options and returns its result.
Assert(ctx context.Context, opts *AssertOptions) (*AssertionResult, error)
}
// NewLLMService builds an ILLMService for the given model type.
//
// The model configuration is resolved once through GetModelConfig and that
// same configuration is used to construct both the planner and the asserter.
// The first error encountered is returned unchanged, together with a nil
// service.
func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) {
	cfg, err := GetModelConfig(modelType)
	if err != nil {
		return nil, err
	}

	svc := &combinedLLMService{}
	if svc.planner, err = NewPlanner(context.Background(), cfg); err != nil {
		return nil, err
	}
	if svc.asserter, err = NewAsserter(context.Background(), cfg); err != nil {
		return nil, err
	}
	return svc, nil
}
// combinedLLMService implements ILLMService by combining a planner and an
// asserter.
// Because the two are held as separate fields, planning and assertion can be
// backed by different model services.
type combinedLLMService struct {
planner IPlanner // provides the planning capability
asserter IAsserter // provides the assertion capability
}
// Call performs planning by delegating to the underlying planner and returns
// its result and error as-is.
func (c *combinedLLMService) Call(ctx context.Context, opts *PlanningOptions) (*PlanningResult, error) {
return c.planner.Call(ctx, opts)
}
// Assert performs an assertion by delegating to the underlying asserter and
// returns its result and error as-is.
func (c *combinedLLMService) Assert(ctx context.Context, opts *AssertOptions) (*AssertionResult, error) {
return c.asserter.Assert(ctx, opts)
}
// Names of the environment variables from which GetModelConfig reads the LLM
// model configuration. All three must be set to a non-empty value.
const (
// EnvOpenAIBaseURL names the variable holding the base URL of the model service.
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
// EnvOpenAIAPIKey names the variable holding the API key for the model service.
EnvOpenAIAPIKey = "OPENAI_API_KEY"
// EnvModelName names the variable holding the model name to use.
EnvModelName = "LLM_MODEL_NAME"
)
const (
// defaultTimeout is the value GetModelConfig assigns to the Timeout field
// of the chat model configuration.
defaultTimeout = 30 * time.Second
)
// ModelConfig bundles the embedded OpenAI chat model configuration with the
// LLM service type it was requested for.
type ModelConfig struct {
*openai.ChatModelConfig
ModelType option.LLMServiceType // service type passed to GetModelConfig
}
// GetModelConfig assembles the chat model configuration for modelType from
// environment variables.
//
// It first loads the environment through config.LoadEnv, then requires
// OPENAI_BASE_URL, OPENAI_API_KEY and LLM_MODEL_NAME (checked in that order) to
// be non-empty, and finally verifies that the model name is acceptable for
// modelType. On success the resolved settings are logged, with the API key
// masked, and returned.
//
// Errors:
//   - code.LoadEnvError (wrapped) when loading the environment fails;
//   - code.LLMEnvMissedError (wrapped) naming the first missing variable;
//   - the error from validateModelType when type and name do not match.
func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
	if err := config.LoadEnv(); err != nil {
		return nil, errors.Wrap(code.LoadEnvError, err.Error())
	}

	// Required variables, listed in the order in which a missing one is reported.
	var baseURL, apiKey, modelName string
	required := []struct {
		env string
		dst *string
	}{
		{EnvOpenAIBaseURL, &baseURL},
		{EnvOpenAIAPIKey, &apiKey},
		{EnvModelName, &modelName},
	}
	for _, r := range required {
		*r.dst = os.Getenv(r.env)
		if *r.dst == "" {
			return nil, errors.Wrapf(code.LLMEnvMissedError,
				"env %s missed", r.env)
		}
	}

	// The model name must be compatible with the requested model type.
	if err := validateModelType(modelType, modelName); err != nil {
		return nil, err
	}

	// Sampling parameters; see https://www.volcengine.com/docs/82379/1536429
	var (
		temperature float32 = 0
		topP        float32 = 0.7
	)
	chatConfig := &openai.ChatModelConfig{
		BaseURL:     baseURL,
		APIKey:      apiKey,
		Model:       modelName,
		Timeout:     defaultTimeout,
		Temperature: &temperature,
		TopP:        &topP,
	}

	// Record the effective configuration; the API key is never logged in full.
	log.Info().
		Str("model", chatConfig.Model).
		Str("baseURL", chatConfig.BaseURL).
		Str("apiKey", maskAPIKey(chatConfig.APIKey)).
		Str("timeout", defaultTimeout.String()).
		Msg("get model config")

	result := &ModelConfig{
		ChatModelConfig: chatConfig,
		ModelType:       modelType,
	}
	return result, nil
}
// validateModelType reports whether modelName is acceptable for modelType.
//
// For the UI-TARS type the name must contain "ui-tars"; for the thinking
// vision pro type it must contain both "doubao" and "vision". Any other model
// type is rejected. It returns nil when the combination is accepted and a
// descriptive error otherwise.
func validateModelType(modelType option.LLMServiceType, modelName string) error {
	switch modelType {
	case option.DOUBAO_1_5_UI_TARS_250428:
		if strings.Contains(modelName, "ui-tars") {
			return nil
		}
		return fmt.Errorf("model name %s is not supported for %s", modelName, modelType)

	case option.DOUBAO_1_5_THINKING_VISION_PRO_250428:
		isDoubaoVision := strings.Contains(modelName, "doubao") &&
			strings.Contains(modelName, "vision")
		if isDoubaoVision {
			return nil
		}
		return fmt.Errorf("model name %s is not supported", modelName)

	default:
		return fmt.Errorf("model type %s is not supported", modelType)
	}
}
// maskAPIKey returns a redacted form of key suitable for logging.
//
// Keys longer than eight bytes keep their first and last four bytes with
// "******" in between; anything shorter is replaced entirely by "******" so
// that no part of a short key is exposed.
func maskAPIKey(key string) string {
	const (
		visible = 4
		mask    = "******"
	)

	n := len(key)
	if n > 2*visible {
		return key[:visible] + mask + key[n-visible:]
	}
	return mask
}