From fcddcfb6303dc80e4b02d135b46e0a2be07191fd Mon Sep 17 00:00:00 2001 From: "lilong.129" Date: Wed, 30 Apr 2025 15:17:01 +0800 Subject: [PATCH] refactor: GetModelConfig --- internal/version/VERSION | 2 +- uixt/ai/ai.go | 27 +++++++++++++++++++-------- uixt/ai/asserter.go | 28 ++++++++++++---------------- uixt/ai/asserter_test.go | 5 ++++- uixt/ai/planner.go | 28 +++++++++++----------------- uixt/ai/planner_test.go | 20 ++++++++++++++++---- 6 files changed, 63 insertions(+), 47 deletions(-) diff --git a/internal/version/VERSION b/internal/version/VERSION index 640b4910..695718cf 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2504301431 +v5.0.0-beta-2504301521 diff --git a/uixt/ai/ai.go b/uixt/ai/ai.go index 75f6cf56..5f011a7b 100644 --- a/uixt/ai/ai.go +++ b/uixt/ai/ai.go @@ -20,11 +20,16 @@ type ILLMService interface { } func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) { - planner, err := NewPlanner(context.Background(), modelType) + modelConfig, err := GetModelConfig(modelType) if err != nil { return nil, err } - asserter, err := NewAsserter(context.Background()) + + planner, err := NewPlanner(context.Background(), modelConfig) + if err != nil { + return nil, err + } + asserter, err := NewAsserter(context.Background(), modelConfig) if err != nil { return nil, err } @@ -36,6 +41,7 @@ func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) { } // combinedLLMService 实现了 ILLMService 接口,组合了规划和断言功能 +// ⭐️支持采用不同的模型服务进行规划和断言 type combinedLLMService struct { planner IPlanner // 提供规划功能 asserter IAsserter // 提供断言功能 @@ -58,18 +64,20 @@ const ( EnvModelName = "LLM_MODEL_NAME" ) -var EnvModelUse string - const ( defaultTimeout = 30 * time.Second ) -// GetOpenAIModelConfig get OpenAI config -func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) { +type ModelConfig struct { + *openai.ChatModelConfig + ModelType option.LLMServiceType +} + +// GetModelConfig get OpenAI config +func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) { if err := config.LoadEnv(); err != nil { return nil, errors.Wrap(code.LoadEnvError, err.Error()) } - EnvModelUse = os.Getenv("LLM_MODEL_USE") openaiBaseURL := os.Getenv(EnvOpenAIBaseURL) if openaiBaseURL == "" { @@ -103,7 +111,10 @@ func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) { Str("timeout", defaultTimeout.String()). Msg("get model config") - return modelConfig, nil + return &ModelConfig{ + ChatModelConfig: modelConfig, + ModelType: modelType, + }, nil } // maskAPIKey masks the API key diff --git a/uixt/ai/asserter.go b/uixt/ai/asserter.go index 678a3d9d..214b1c5d 100644 --- a/uixt/ai/asserter.go +++ b/uixt/ai/asserter.go @@ -41,28 +41,23 @@ type AssertionResponse struct { // Asserter handles assertion using different AI models type Asserter struct { ctx context.Context + modelConfig *ModelConfig model model.ToolCallingChatModel systemPrompt string history ConversationHistory } // NewAsserter creates a new Asserter instance -func NewAsserter(ctx context.Context) (*Asserter, error) { +func NewAsserter(ctx context.Context, modelConfig *ModelConfig) (*Asserter, error) { asserter := &Asserter{ ctx: ctx, + modelConfig: modelConfig, systemPrompt: defaultAssertionPrompt, } - config, err := GetOpenAIModelConfig() - if err != nil { - return nil, err - } - - if strings.Contains(EnvModelUse, string(option.LLMServiceTypeUITARS)) { + if modelConfig.ModelType == option.LLMServiceTypeUITARS { asserter.systemPrompt += "\n\n" + uiTarsAssertionResponseFormat - } else if strings.Contains(EnvModelUse, string(option.LLMServiceTypeQwenVL)) { - asserter.systemPrompt += "\n\n" + defaultAssertionResponseJsonFormat - } else if strings.Contains(EnvModelUse, string(option.LLMServiceTypeGPT)) { + } else if modelConfig.ModelType == option.LLMServiceTypeGPT { // define output format type OutputFormat struct { Thought string `json:"thought"` @@ -71,11 +66,11 @@ func NewAsserter(ctx context.Context) (*Asserter, error) { } outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil) if err != nil { - return nil, err + return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error()) } // set structured response format // https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go - config.ResponseFormat = &openai2.ChatCompletionResponseFormat{ + modelConfig.ChatModelConfig.ResponseFormat = &openai2.ChatCompletionResponseFormat{ Type: openai2.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{ Name: "assertion_result", @@ -85,12 +80,13 @@ func NewAsserter(ctx context.Context) (*Asserter, error) { }, } } else { - return nil, fmt.Errorf("model type %s not supported for asserter", EnvModelUse) + asserter.systemPrompt += "\n\n" + defaultAssertionResponseJsonFormat } - asserter.model, err = openai.NewChatModel(ctx, config) + var err error + asserter.model, err = openai.NewChatModel(ctx, modelConfig.ChatModelConfig) if err != nil { - return nil, err + return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error()) } return asserter, nil @@ -142,7 +138,7 @@ Here is the assertion. Please tell whether it is truthy according to the screens startTime := time.Now() resp, err := a.model.Generate(a.ctx, a.history) log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", EnvModelUse).Msg("call model service for assertion") + Str("model", string(a.modelConfig.ModelType)).Msg("call model service for assertion") if err != nil { return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) } diff --git a/uixt/ai/asserter_test.go b/uixt/ai/asserter_test.go index 981216be..3587bc6f 100644 --- a/uixt/ai/asserter_test.go +++ b/uixt/ai/asserter_test.go @@ -4,13 +4,16 @@ import ( "context" "testing" + "github.com/httprunner/httprunner/v5/uixt/option" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func createAsserter(t *testing.T) *Asserter { - asserter, err := NewAsserter(context.Background()) + modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS) + require.NoError(t, err) + asserter, err := NewAsserter(context.Background(), modelConfig) require.NoError(t, err) return asserter } diff --git a/uixt/ai/planner.go b/uixt/ai/planner.go index 00b76e80..228c826c 100644 --- a/uixt/ai/planner.go +++ b/uixt/ai/planner.go @@ -2,7 +2,6 @@ package ai import ( "context" - "fmt" "time" "github.com/cloudwego/eino-ext/components/model/openai" @@ -33,26 +32,22 @@ type PlanningResult struct { Error string `json:"error,omitempty"` } -func NewPlanner(ctx context.Context, modelType option.LLMServiceType) (*Planner, error) { +func NewPlanner(ctx context.Context, modelConfig *ModelConfig) (*Planner, error) { planner := &Planner{ - ctx: ctx, - modelType: modelType, + ctx: ctx, + modelConfig: modelConfig, } - config, err := GetOpenAIModelConfig() - if err != nil { - return nil, fmt.Errorf("failed to create OpenAI config: %w", err) - } - - if modelType == option.LLMServiceTypeUITARS { + if modelConfig.ModelType == option.LLMServiceTypeUITARS { planner.systemPrompt = uiTarsPlanningPrompt } else { planner.systemPrompt = defaultPlanningResponseJsonFormat } - planner.model, err = openai.NewChatModel(ctx, config) + var err error + planner.model, err = openai.NewChatModel(ctx, modelConfig.ChatModelConfig) if err != nil { - return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err) + return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error()) } return planner, nil @@ -60,9 +55,9 @@ func NewPlanner(ctx context.Context, modelType option.LLMServiceType) (*Planner, type Planner struct { ctx context.Context + modelConfig *ModelConfig model model.ToolCallingChatModel systemPrompt string - modelType option.LLMServiceType history ConversationHistory } @@ -76,11 +71,10 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { // prepare prompt if len(p.history) == 0 { // add system message - systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction p.history = ConversationHistory{ { Role: schema.System, - Content: systemPrompt, + Content: p.systemPrompt + opts.UserInstruction, }, } } @@ -92,7 +86,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { startTime := time.Now() resp, err := p.model.Generate(p.ctx, p.history) log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", string(p.modelType)).Msg("call model service") + Str("model", string(p.modelConfig.ModelType)).Msg("call model service") if err != nil { return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) } @@ -116,7 +110,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) { var parseActions []ParsedAction var err error - if p.modelType == option.LLMServiceTypeUITARS { + if p.modelConfig.ModelType == option.LLMServiceTypeUITARS { // parse Thought/Action format from UI-TARS parseActions, err = parseThoughtAction(msg.Content) if err != nil { diff --git a/uixt/ai/planner_test.go b/uixt/ai/planner_test.go index 9f3faddb..98ac862e 100644 --- a/uixt/ai/planner_test.go +++ b/uixt/ai/planner_test.go @@ -36,7 +36,10 @@ func TestVLMPlanning(t *testing.T) { userInstruction += "\n\n请基于以上游戏规则,给出下一步可点击的两个图标坐标" - planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) + modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS) + require.NoError(t, err) + + planner, err := NewPlanner(context.Background(), modelConfig) require.NoError(t, err) opts := &PlanningOptions{ @@ -106,7 +109,10 @@ func TestXHSPlanning(t *testing.T) { userInstruction := "点击第二个帖子的作者头像" - planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) + modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS) + require.NoError(t, err) + + planner, err := NewPlanner(context.Background(), modelConfig) require.NoError(t, err) opts := &PlanningOptions{ @@ -176,7 +182,10 @@ func TestChatList(t *testing.T) { userInstruction := "请结合图片的文字信息,请告诉我一共有多少个群聊,哪些群聊右下角有绿点" - planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) + modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS) + require.NoError(t, err) + + planner, err := NewPlanner(context.Background(), modelConfig) require.NoError(t, err) opts := &PlanningOptions{ @@ -207,7 +216,10 @@ func TestHandleSwitch(t *testing.T) { userInstruction := "发送框下方的联网搜索开关是开启状态" // 点击开启联网搜索开关 // 检查发送框下方的联网搜索开关,蓝色为开启状态,灰色为关闭状态;若开关处于关闭状态,则点击进行开启 - planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) + modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS) + require.NoError(t, err) + + planner, err := NewPlanner(context.Background(), modelConfig) require.NoError(t, err) testCases := []struct {