diff --git a/internal/version/VERSION b/internal/version/VERSION index 4f413300..13a7f720 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2504291210 +v5.0.0-beta-2504292008 diff --git a/uixt/ai/ai.go b/uixt/ai/ai.go index 78b43f77..3d039820 100644 --- a/uixt/ai/ai.go +++ b/uixt/ai/ai.go @@ -49,6 +49,8 @@ type LLMServiceType string const ( LLMServiceTypeUITARS LLMServiceType = "ui-tars" LLMServiceTypeGPT4o LLMServiceType = "gpt-4o" + LLMServiceTypeGPT4Vision LLMServiceType = "gpt-4-vision" + LLMServiceTypeQwenVL LLMServiceType = "qwen-vl" LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3" ) @@ -58,45 +60,33 @@ type ILLMService interface { Assert(opts *AssertOptions) (*AssertionResponse, error) } -func WithLLMService(service LLMServiceType) AIServiceOption { +func WithLLMService(modelType LLMServiceType) AIServiceOption { return func(opts *AIServices) { - switch service { + // init planner + var planner IPlanner + var err error + switch modelType { case LLMServiceTypeGPT4o: // TODO: implement gpt-4o planner and asserter - planner, err := NewPlanner(context.Background()) - if err != nil { - log.Error().Err(err).Msg("init gpt-4o planner failed") - os.Exit(code.GetErrorCode(err)) - } - - asserter, err := NewUITarsAsserter(context.Background()) - if err != nil { - log.Error().Err(err).Msg("init ui-tars asserter failed") - os.Exit(code.GetErrorCode(err)) - } - - opts.ILLMService = &combinedLLMService{ - planner: planner, - asserter: asserter, - } - + planner, err = NewPlanner(context.Background()) case LLMServiceTypeUITARS: - planner, err := NewUITarsPlanner(context.Background()) - if err != nil { - log.Error().Err(err).Msg("init ui-tars planner failed") - os.Exit(code.GetErrorCode(err)) - } + planner, err = NewUITarsPlanner(context.Background()) + } + if err != nil { + log.Error().Err(err).Msgf("init %s planner failed", modelType) + os.Exit(code.GetErrorCode(err)) + } - asserter, err := NewUITarsAsserter(context.Background()) - if err != nil { - log.Error().Err(err).Msg("init ui-tars asserter failed") - os.Exit(code.GetErrorCode(err)) - } + // init asserter + asserter, err := NewAsserter(context.Background(), modelType) + if err != nil { + log.Error().Err(err).Msgf("init %s asserter failed", modelType) + os.Exit(code.GetErrorCode(err)) + } - opts.ILLMService = &combinedLLMService{ - planner: planner, - asserter: asserter, - } + opts.ILLMService = &combinedLLMService{ + planner: planner, + asserter: asserter, } } } diff --git a/uixt/ai/ai_ark.go b/uixt/ai/ai_ark.go new file mode 100644 index 00000000..265d4afb --- /dev/null +++ b/uixt/ai/ai_ark.go @@ -0,0 +1,62 @@ +package ai + +import ( + "os" + + "github.com/cloudwego/eino-ext/components/model/ark" + "github.com/httprunner/httprunner/v5/code" + "github.com/httprunner/httprunner/v5/internal/config" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" +) + +const ( + EnvArkBaseURL = "ARK_BASE_URL" + EnvArkAPIKey = "ARK_API_KEY" + EnvArkModelID = "ARK_MODEL_ID" +) + +func GetArkModelConfig() (*ark.ChatModelConfig, error) { + if err := config.LoadEnv(); err != nil { + return nil, errors.Wrap(code.LoadEnvError, err.Error()) + } + + arkBaseURL := os.Getenv(EnvArkBaseURL) + arkAPIKey := os.Getenv(EnvArkAPIKey) + if arkAPIKey == "" { + return nil, errors.Wrapf(code.LLMEnvMissedError, + "env %s missed", EnvArkAPIKey) + } + modelName := os.Getenv(EnvArkModelID) + if modelName == "" { + return nil, errors.Wrapf(code.LLMEnvMissedError, + "env %s missed", EnvArkModelID) + } + timeout := defaultTimeout + 
+	// https://www.volcengine.com/docs/82379/1494384?redirect=1
+	temperature := float32(0.01) // [0, 2] sampling temperature; controls how strongly the probability distribution over candidate tokens is smoothed during generation.
+	// topP := float32(0.7)           // [0, 1] nucleus sampling threshold; the model only considers tokens whose cumulative probability mass falls within top_p.
+	// maxTokens := int(4096)         // maximum number of tokens the model may generate; the total length of input and output tokens is also limited by the model's context window.
+	// frequencyPenalty := float32(0) // [-2, 2] frequency penalty; positive values penalize new tokens based on how often they already appear in the text, reducing verbatim repetition.
+
+	modelConfig := &ark.ChatModelConfig{
+		BaseURL:     arkBaseURL,
+		APIKey:      arkAPIKey,
+		Model:       modelName,
+		Timeout:     &timeout,
+		Temperature: &temperature,
+		// TopP:             &topP,
+		// MaxTokens:        &maxTokens,
+		// FrequencyPenalty: &frequencyPenalty,
+	}
+
+	// log config info
+	log.Info().Str("model", modelConfig.Model).
+		Str("baseURL", modelConfig.BaseURL).
+		Str("apiKey", maskAPIKey(modelConfig.APIKey)).
+		Str("timeout", defaultTimeout.String()).
+		Msg("get model config")
+
+	return modelConfig, nil
+}
diff --git a/uixt/ai/ai_openai.go b/uixt/ai/ai_openai.go
new file mode 100644
index 00000000..85d35ab4
--- /dev/null
+++ b/uixt/ai/ai_openai.go
@@ -0,0 +1,79 @@
+package ai
+
+import (
+	"os"
+
+	"github.com/cloudwego/eino-ext/components/model/openai"
+	openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
+	"github.com/getkin/kin-openapi/openapi3gen"
+	"github.com/httprunner/httprunner/v5/code"
+	"github.com/httprunner/httprunner/v5/internal/config"
+	"github.com/pkg/errors"
+	"github.com/rs/zerolog/log"
+)
+
+const (
+	EnvOpenAIBaseURL = "OPENAI_BASE_URL"
+	EnvOpenAIAPIKey  = "OPENAI_API_KEY"
+	EnvModelName     = "LLM_MODEL_NAME"
+)
+
+// GetOpenAIModelConfig gets the OpenAI config
+func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
+	if err := config.LoadEnv(); err != nil {
+		return nil, errors.Wrap(code.LoadEnvError, err.Error())
+	}
+
+	openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
+	if openaiBaseURL == "" {
+		return nil, errors.Wrapf(code.LLMEnvMissedError,
+			"env %s missed", EnvOpenAIBaseURL)
+	}
+	openaiAPIKey := os.Getenv(EnvOpenAIAPIKey)
+	if openaiAPIKey == "" {
+		return nil, errors.Wrapf(code.LLMEnvMissedError,
+			"env %s missed", EnvOpenAIAPIKey)
+	}
+	modelName := os.Getenv(EnvModelName)
+	if modelName == "" {
+		return nil, errors.Wrapf(code.LLMEnvMissedError,
+			"env %s missed", EnvModelName)
+	}
+
+	type OutputFormat struct {
+		Thought string `json:"thought"`
+		Action  string `json:"action"`
+		Error   string `json:"error,omitempty"`
+	}
+	outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	modelConfig := &openai.ChatModelConfig{
+		BaseURL: openaiBaseURL,
+		APIKey:  openaiAPIKey,
+		Model:   modelName,
+		Timeout: defaultTimeout,
+		// set structured response format
+		// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
+		ResponseFormat: &openai2.ChatCompletionResponseFormat{
+			Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
+			JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
+				Name:        "thought_and_action",
+				Description: "data that describes planning thought and action",
+				Schema:      outputFormatSchema.Value,
+				Strict:      false,
+			},
+		},
+	}
+
+	// log config info
+	log.Info().Str("model", modelConfig.Model).
+		Str("baseURL", modelConfig.BaseURL).
+		Str("apiKey", maskAPIKey(modelConfig.APIKey)).
+		Str("timeout", defaultTimeout.String()).
+ Msg("get model config") + + return modelConfig, nil +} diff --git a/uixt/ai/asserter_ui_tars.go b/uixt/ai/asserter.go similarity index 51% rename from uixt/ai/asserter_ui_tars.go rename to uixt/ai/asserter.go index 530b317b..9e0d628b 100644 --- a/uixt/ai/asserter_ui_tars.go +++ b/uixt/ai/asserter.go @@ -8,6 +8,8 @@ import ( "time" "github.com/cloudwego/eino-ext/components/model/ark" + "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" "github.com/httprunner/httprunner/v5/code" "github.com/httprunner/httprunner/v5/internal/json" @@ -16,60 +18,11 @@ import ( "github.com/rs/zerolog/log" ) +// IAsserter interface defines the contract for assertion operations type IAsserter interface { Assert(opts *AssertOptions) (*AssertionResponse, error) } -// UI-TARS assertion system prompt -const uiTarsAssertionPrompt = `You are a senior testing engineer. User will give an assertion and a screenshot of a page. By carefully viewing the screenshot, please tell whether the assertion is truthy. - -## Output Json String Format -` + "```" + ` -"{ - "pass": <>, - "thought": "<>" -}" -` + "```" + ` - -## Rules **MUST** follow -- Make sure to return **only** the JSON, with **no additional** text or explanations. -- Use Chinese in 'thought' part. -- You **MUST** strictly follow up the **Output Json String Format**.` - -// AssertionResponse represents the response from an AI assertion -type AssertionResponse struct { - Pass bool `json:"pass"` - Thought string `json:"thought"` -} - -// UITarsAsserter handles assertion using UI-TARS VLM -type UITarsAsserter struct { - ctx context.Context - model *ark.ChatModel - config *ark.ChatModelConfig - systemPrompt string - history ConversationHistory -} - -// NewUITarsAsserter creates a new UITarsAsserter instance -func NewUITarsAsserter(ctx context.Context) (*UITarsAsserter, error) { - config, err := GetArkModelConfig() - if err != nil { - return nil, err - } - chatModel, err := ark.NewChatModel(ctx, config) - if err != nil { - return nil, err - } - - return &UITarsAsserter{ - ctx: ctx, - config: config, - model: chatModel, - systemPrompt: uiTarsAssertionPrompt, - }, nil -} - // AssertOptions represents the input options for assertion type AssertOptions struct { Assertion string `json:"assertion"` // The assertion text to verify @@ -77,18 +30,65 @@ type AssertOptions struct { Size types.Size `json:"size"` // Screen dimensions } -func validateAssertionInput(opts *AssertOptions) error { - if opts.Assertion == "" { - return errors.Wrap(code.LLMPrepareRequestError, "assertion text is required") +// AssertionResponse represents the response from an AI assertion +type AssertionResponse struct { + Pass bool `json:"pass"` + Thought string `json:"thought"` +} + +// Asserter handles assertion using different AI models +type Asserter struct { + ctx context.Context + model model.ToolCallingChatModel + systemPrompt string + history ConversationHistory + modelType LLMServiceType +} + +// NewAsserter creates a new Asserter instance +func NewAsserter(ctx context.Context, modelType LLMServiceType) (*Asserter, error) { + asserter := &Asserter{ + ctx: ctx, + modelType: modelType, + systemPrompt: getAssertionSystemPrompt(modelType), } - if opts.Screenshot == "" { - return errors.Wrap(code.LLMPrepareRequestError, "screenshot is required") + + switch modelType { + case LLMServiceTypeUITARS: + config, err := GetArkModelConfig() + if err != nil { + return nil, err + } + asserter.model, err = ark.NewChatModel(ctx, 
config) + if err != nil { + return nil, err + } + case LLMServiceTypeGPT4Vision, LLMServiceTypeGPT4o: + config, err := GetOpenAIModelConfig() + if err != nil { + return nil, err + } + asserter.model, err = openai.NewChatModel(ctx, config) + if err != nil { + return nil, err + } + default: + return nil, errors.New("not supported model type for asserter") } - return nil + + return asserter, nil +} + +// getAssertionSystemPrompt returns the appropriate system prompt for the given model type +func getAssertionSystemPrompt(modelType LLMServiceType) string { + if modelType == LLMServiceTypeUITARS { + return defaultAssertionPrompt + "\n\n" + uiTarsAssertionResponseFormat + } + return defaultAssertionPrompt + "\n\n" + defaultAssertionResponseJsonFormat } // Assert performs the assertion check on the screenshot -func (a *UITarsAsserter) Assert(opts *AssertOptions) (*AssertionResponse, error) { +func (a *Asserter) Assert(opts *AssertOptions) (*AssertionResponse, error) { // Validate input parameters if err := validateAssertionInput(opts); err != nil { return nil, errors.Wrap(err, "validate assertion parameters failed") @@ -133,7 +133,7 @@ Here is the assertion. Please tell whether it is truthy according to the screens startTime := time.Now() resp, err := a.model.Generate(a.ctx, a.history) log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", a.config.Model).Msg("call model service for assertion") + Str("model", string(a.modelType)).Msg("call model service for assertion") if err != nil { return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) } @@ -154,78 +154,36 @@ Here is the assertion. Please tell whether it is truthy according to the screens return result, nil } -// parseAssertionResult 解析模型返回的JSON响应 +// validateAssertionInput validates the input parameters for assertion +func validateAssertionInput(opts *AssertOptions) error { + if opts.Assertion == "" { + return errors.Wrap(code.LLMPrepareRequestError, "assertion text is required") + } + if opts.Screenshot == "" { + return errors.Wrap(code.LLMPrepareRequestError, "screenshot is required") + } + return nil +} + +// parseAssertionResult parses the model response into AssertionResponse func parseAssertionResult(content string) (*AssertionResponse, error) { - // 1. 从响应中提取JSON内容 + // Extract JSON content from response jsonContent := extractJSON(content) if jsonContent == "" { return nil, errors.New("could not extract JSON from response") } - // 2. 预处理和标准解析尝试 - jsonContent = prepareJSON(jsonContent) + // Parse JSON response var result AssertionResponse - if err := json.Unmarshal([]byte(jsonContent), &result); err == nil { - return &result, nil + if err := json.Unmarshal([]byte(jsonContent), &result); err != nil { + return nil, errors.Wrap(code.LLMParseAssertionResponseError, err.Error()) } - // 3. 备用:正则表达式解析 - if pass, thought := extractWithRegex(jsonContent); thought != "" { - return &AssertionResponse{Pass: pass, Thought: thought}, nil - } - - return nil, errors.New("failed to parse assertion result") -} - -// prepareJSON 预处理JSON字符串,修复常见问题 -func prepareJSON(jsonStr string) string { - // 1. 去除可能的外层引号 - jsonStr = strings.TrimSpace(jsonStr) - if strings.HasPrefix(jsonStr, "\"") && strings.HasSuffix(jsonStr, "\"") { - jsonStr = jsonStr[1 : len(jsonStr)-1] - } - - // 2. 
转义thought内容中的引号 - thoughtRegex := regexp.MustCompile(`"thought":\s*"([^"]*)"`) - matches := thoughtRegex.FindStringSubmatch(jsonStr) - if len(matches) > 1 { - thoughtValue := matches[1] - fixedThought := strings.ReplaceAll(thoughtValue, "\"", "\\\"") - jsonStr = strings.Replace(jsonStr, matches[0], fmt.Sprintf(`"thought": "%s"`, fixedThought), 1) - } - - // 3. 处理换行和特殊字符 - jsonStr = strings.ReplaceAll(jsonStr, "\n", "\\n") - jsonStr = strings.ReplaceAll(jsonStr, "\r", "\\r") - jsonStr = strings.ReplaceAll(jsonStr, "\t", "\\t") - - return jsonStr -} - -// extractWithRegex 使用正则表达式提取pass和thought值 -func extractWithRegex(jsonStr string) (pass bool, thought string) { - // 提取pass值 - passRegex := regexp.MustCompile(`"pass":\s*(true|false)`) - passMatches := passRegex.FindStringSubmatch(jsonStr) - - // 提取thought值 - thoughtRegex := regexp.MustCompile(`"thought":\s*"([^"]*(?:"[^"]*)*)"`) - thoughtMatches := thoughtRegex.FindStringSubmatch(jsonStr) - - if len(passMatches) > 1 && len(thoughtMatches) > 1 { - // 处理提取的值 - pass = passMatches[1] == "true" - thought = strings.ReplaceAll(thoughtMatches[1], "\\\"", "\"") - thought = strings.ReplaceAll(thought, "\\\\", "\\") - return pass, thought - } - - return false, "" + return &result, nil } // extractJSON extracts JSON content from a string that might contain markdown or other formatting func extractJSON(content string) string { - // Try to extract JSON directly content = strings.TrimSpace(content) // If the content is already a valid JSON, return it @@ -233,7 +191,7 @@ func extractJSON(content string) string { return content } - // Check for markdown code blocks with more flexible pattern + // Try to extract JSON from markdown code blocks jsonRegex := regexp.MustCompile(`(?:json)?\s*({[\s\S]*?})\s*`) matches := jsonRegex.FindStringSubmatch(content) if len(matches) > 1 { @@ -241,7 +199,6 @@ func extractJSON(content string) string { } // Try a more robust approach for JSON with Chinese characters - // First look for the outermost pair of curly braces startIdx := strings.Index(content, "{") if startIdx >= 0 { depth := 1 @@ -251,19 +208,11 @@ func extractJSON(content string) string { } else if content[i] == '}' { depth-- if depth == 0 { - // Found the closing brace return content[startIdx : i+1] } } } } - // Fallback to regex approach - braceRegex := regexp.MustCompile(`{[\s\S]*?}`) - matches = braceRegex.FindStringSubmatch(content) - if len(matches) > 0 { - return strings.TrimSpace(matches[0]) - } - - return "" + return content } diff --git a/uixt/ai/asserter_prompts.go b/uixt/ai/asserter_prompts.go new file mode 100644 index 00000000..9ceb092d --- /dev/null +++ b/uixt/ai/asserter_prompts.go @@ -0,0 +1,25 @@ +package ai + +// Default assertion system prompt +const defaultAssertionPrompt = `You are a senior testing engineer. User will give an assertion and a screenshot of a page. By carefully viewing the screenshot, please tell whether the assertion is truthy.` + +// Default assertion response format +const defaultAssertionResponseJsonFormat = `Return in the following JSON format: +{ + pass: boolean, // whether the assertion is truthy + thought: string | null, // string, if the result is falsy, give the reason why it is falsy. Otherwise, put null. +}` + +// UI-TARS assertion response format +const uiTarsAssertionResponseFormat = `## Output Json String Format +` + "```" + ` +"{ + "pass": <>, + "thought": "<>" +}" +` + "```" + ` + +## Rules **MUST** follow +- Make sure to return **only** the JSON, with **no additional** text or explanations. 
+- Use Chinese in ` + "`Thought`" + ` part. +- You **MUST** strictly follow up the **Output Json String Format**.` diff --git a/uixt/ai/planner.go b/uixt/ai/planner.go index 027c7d3b..39435570 100644 --- a/uixt/ai/planner.go +++ b/uixt/ai/planner.go @@ -16,7 +16,6 @@ import ( "github.com/httprunner/httprunner/v5/code" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/pkg/errors" - "github.com/rs/zerolog/log" ) type IPlanner interface { @@ -85,100 +84,6 @@ func validatePlanningInput(opts *PlanningOptions) error { return nil } -func logRequest(messages []*schema.Message) { - msgs := make([]*schema.Message, 0, len(messages)) - for _, message := range messages { - msg := &schema.Message{ - Role: message.Role, - } - if message.Content != "" { - msg.Content = message.Content - } else if len(message.MultiContent) > 0 { - for _, mc := range message.MultiContent { - switch mc.Type { - case schema.ChatMessagePartTypeImageURL: - // Create a copy of the ImageURL to avoid modifying the original message - imageURLCopy := *mc.ImageURL - if strings.HasPrefix(imageURLCopy.URL, "data:image/") { - imageURLCopy.URL = "" - } - msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{ - Type: mc.Type, - ImageURL: &imageURLCopy, - }) - } - } - } - msgs = append(msgs, msg) - } - log.Debug().Interface("messages", msgs).Msg("log request messages") -} - -func logResponse(resp *schema.Message) { - logger := log.Info().Str("role", string(resp.Role)). - Str("content", resp.Content) - if resp.ResponseMeta != nil { - logger = logger.Interface("response_meta", resp.ResponseMeta) - } - if resp.Extra != nil { - logger = logger.Interface("extra", resp.Extra) - } - logger.Msg("log response message") -} - -type ConversationHistory []*schema.Message - -// Append adds a message to the conversation history -func (h *ConversationHistory) Append(msg *schema.Message) { - // for user image message: - // - keep at most 4 user image messages - // - delete the oldest user image message when the limit is reached - if msg.Role == schema.User { - // get all existing user messages - userImgCount := 0 - firstUserImgIndex := -1 - - // calculate the number of user messages and find the index of the first user message - for i, item := range *h { - if item.Role == schema.User { - userImgCount++ - if firstUserImgIndex == -1 { - firstUserImgIndex = i - } - } - } - - // if there are already 4 user messages, delete the first one before adding the new message - if userImgCount >= 4 && firstUserImgIndex >= 0 { - // delete the first user message - *h = append( - (*h)[:firstUserImgIndex], - (*h)[firstUserImgIndex+1:]..., - ) - } - // add the new user message to the history - *h = append(*h, msg) - } - - // for assistant message: - // - keep at most the last 10 assistant messages - if msg.Role == schema.Assistant { - // add the new assistant message to the history - *h = append(*h, msg) - - // if there are more than 10 assistant messages, remove the oldest ones - assistantMsgCount := 0 - for i := len(*h) - 1; i >= 0; i-- { - if (*h)[i].Role == schema.Assistant { - assistantMsgCount++ - if assistantMsgCount > 10 { - *h = append((*h)[:i], (*h)[i+1:]...) 
- } - } - } - } -} - // SavePositionImg saves an image with position markers func SavePositionImg(params struct { InputImgBase64 string diff --git a/uixt/ai/planner_gpt.go b/uixt/ai/planner_gpt.go index 60a95994..d4527201 100644 --- a/uixt/ai/planner_gpt.go +++ b/uixt/ai/planner_gpt.go @@ -4,90 +4,20 @@ import ( "context" "fmt" _ "image/jpeg" - "os" "strings" "time" "github.com/cloudwego/eino-ext/components/model/openai" - openai2 "github.com/cloudwego/eino-ext/libs/acl/openai" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" - "github.com/getkin/kin-openapi/openapi3gen" "github.com/pkg/errors" "github.com/rs/zerolog/log" "github.com/httprunner/httprunner/v5/code" - "github.com/httprunner/httprunner/v5/internal/config" "github.com/httprunner/httprunner/v5/internal/json" "github.com/httprunner/httprunner/v5/uixt/types" ) -const ( - EnvOpenAIBaseURL = "OPENAI_BASE_URL" - EnvOpenAIAPIKey = "OPENAI_API_KEY" - EnvModelName = "LLM_MODEL_NAME" -) - -// GetOpenAIModelConfig get OpenAI config -func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) { - if err := config.LoadEnv(); err != nil { - return nil, errors.Wrap(code.LoadEnvError, err.Error()) - } - - openaiBaseURL := os.Getenv(EnvOpenAIBaseURL) - if openaiBaseURL == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvOpenAIBaseURL) - } - openaiAPIKey := os.Getenv(EnvOpenAIAPIKey) - if openaiAPIKey == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvOpenAIAPIKey) - } - modelName := os.Getenv(EnvModelName) - if modelName == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvModelName) - } - - type OutputFormat struct { - Thought string `json:"thought"` - Action string `json:"action"` - Error string `json:"error,omitempty"` - } - outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil) - if err != nil { - return nil, err - } - - modelConfig := &openai.ChatModelConfig{ - BaseURL: openaiBaseURL, - APIKey: openaiAPIKey, - Model: modelName, - Timeout: defaultTimeout, - // set structured response format - // https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go - ResponseFormat: &openai2.ChatCompletionResponseFormat{ - Type: openai2.ChatCompletionResponseFormatTypeJSONSchema, - JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{ - Name: "thought_and_action", - Description: "data that describes planning thought and action", - Schema: outputFormatSchema.Value, - Strict: false, - }, - }, - } - - // log config info - log.Info().Str("model", modelConfig.Model). - Str("baseURL", modelConfig.BaseURL). - Str("apiKey", maskAPIKey(modelConfig.APIKey)). - Str("timeout", defaultTimeout.String()). 
- Msg("get model config") - - return modelConfig, nil -} - func NewPlanner(ctx context.Context) (*Planner, error) { config, err := GetOpenAIModelConfig() if err != nil { @@ -99,8 +29,8 @@ func NewPlanner(ctx context.Context) (*Planner, error) { } return &Planner{ ctx: ctx, - config: config, model: model, + modelType: LLMServiceTypeGPT4o, systemPrompt: uiTarsPlanningPrompt, // TODO: change prompt with function calling }, nil } @@ -108,8 +38,8 @@ func NewPlanner(ctx context.Context) (*Planner, error) { type Planner struct { ctx context.Context model model.ToolCallingChatModel - config *openai.ChatModelConfig systemPrompt string + modelType LLMServiceType history ConversationHistory } @@ -139,7 +69,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { startTime := time.Now() resp, err := p.model.Generate(p.ctx, p.history) log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", p.config.Model).Msg("call model service") + Str("model", string(p.modelType)).Msg("call model service") if err != nil { return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) } diff --git a/uixt/ai/planner_prompts.go b/uixt/ai/planner_prompts.go new file mode 100644 index 00000000..d303c87f --- /dev/null +++ b/uixt/ai/planner_prompts.go @@ -0,0 +1,29 @@ +package ai + +// https://www.volcengine.com/docs/82379/1536429 +const uiTarsPlanningPrompt = ` +You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +` + "```" + ` +Thought: ... +Action: ... +` + "```" + ` + +## Action Space +click(start_box='[x1, y1, x2, y2]') +left_double(start_box='[x1, y1, x2, y2]') +right_single(start_box='[x1, y1, x2, y2]') +drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]') +hotkey(key='') +type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `. +scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. + +## Note +- Use Chinese in ` + "`Thought`" + ` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part. 
+ +## User Instruction +` diff --git a/uixt/ai/planner_ui_tars.go b/uixt/ai/planner_ui_tars.go index 57f96d59..83184233 100644 --- a/uixt/ai/planner_ui_tars.go +++ b/uixt/ai/planner_ui_tars.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "math" - "os" "regexp" "strconv" "strings" @@ -14,64 +13,12 @@ import ( "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" "github.com/httprunner/httprunner/v5/code" - "github.com/httprunner/httprunner/v5/internal/config" "github.com/httprunner/httprunner/v5/internal/json" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/pkg/errors" "github.com/rs/zerolog/log" ) -const ( - EnvArkBaseURL = "ARK_BASE_URL" - EnvArkAPIKey = "ARK_API_KEY" - EnvArkModelID = "ARK_MODEL_ID" -) - -func GetArkModelConfig() (*ark.ChatModelConfig, error) { - if err := config.LoadEnv(); err != nil { - return nil, errors.Wrap(code.LoadEnvError, err.Error()) - } - - arkBaseURL := os.Getenv(EnvArkBaseURL) - arkAPIKey := os.Getenv(EnvArkAPIKey) - if arkAPIKey == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvArkAPIKey) - } - modelName := os.Getenv(EnvArkModelID) - if modelName == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvArkModelID) - } - timeout := defaultTimeout - - // https://www.volcengine.com/docs/82379/1494384?redirect=1 - temperature := float32(0.01) // [0, 2] 采样温度。控制了生成文本时对每个候选词的概率分布进行平滑的程度。 - // topP := float32(0.7) // [0, 1] 核采样概率阈值。模型会考虑概率质量在 top_p 内的 token 结果。 - // maxTokens := int(4096) // 模型可以生成的最大 token 数量。输入 token 和输出 token 的总长度还受模型的上下文长度限制。 - // frequencyPenalty := float32(0) // [-2, 2] 频率惩罚系数。如果值为正,会根据新 token 在文本中的出现频率对其进行惩罚,从而降低模型逐字重复的可能性。 - - modelConfig := &ark.ChatModelConfig{ - BaseURL: arkBaseURL, - APIKey: arkAPIKey, - Model: modelName, - Timeout: &timeout, - Temperature: &temperature, - // TopP: &topP, - // MaxTokens: &maxTokens, - // FrequencyPenalty: &frequencyPenalty, - } - - // log config info - log.Info().Str("model", modelConfig.Model). - Str("baseURL", modelConfig.BaseURL). - Str("apiKey", maskAPIKey(modelConfig.APIKey)). - Str("timeout", defaultTimeout.String()). - Msg("get model config") - - return modelConfig, nil -} - func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) { config, err := GetArkModelConfig() if err != nil { @@ -84,45 +31,17 @@ func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) { return &UITarsPlanner{ ctx: ctx, - config: config, model: chatModel, + modelType: LLMServiceTypeUITARS, systemPrompt: uiTarsPlanningPrompt, }, nil } -// https://www.volcengine.com/docs/82379/1536429 -const uiTarsPlanningPrompt = ` -You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. - -## Output Format -` + "```" + ` -Thought: ... -Action: ... -` + "```" + ` - -## Action Space -click(start_box='[x1, y1, x2, y2]') -left_double(start_box='[x1, y1, x2, y2]') -right_single(start_box='[x1, y1, x2, y2]') -drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]') -hotkey(key='') -type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `. -scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left') -wait() #Sleep for 5s and take a screenshot to check for any changes. -finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. - -## Note -- Use Chinese in ` + "`Thought`" + ` part. 
-- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part. - -## User Instruction -` - type UITarsPlanner struct { ctx context.Context model model.ToolCallingChatModel - config *ark.ChatModelConfig systemPrompt string + modelType LLMServiceType history ConversationHistory } @@ -152,7 +71,7 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) { startTime := time.Now() resp, err := p.model.Generate(p.ctx, p.history) log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", p.config.Model).Msg("call model service") + Str("model", string(p.modelType)).Msg("call model service") if err != nil { return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) } diff --git a/uixt/ai/session.go b/uixt/ai/session.go new file mode 100644 index 00000000..659ccc88 --- /dev/null +++ b/uixt/ai/session.go @@ -0,0 +1,103 @@ +package ai + +import ( + "strings" + + "github.com/cloudwego/eino/schema" + "github.com/rs/zerolog/log" +) + +// ConversationHistory represents a sequence of chat messages +type ConversationHistory []*schema.Message + +// Append adds a new message to the conversation history +func (h *ConversationHistory) Append(msg *schema.Message) { + // for user image message: + // - keep at most 4 user image messages + // - delete the oldest user image message when the limit is reached + if msg.Role == schema.User { + // get all existing user messages + userImgCount := 0 + firstUserImgIndex := -1 + + // calculate the number of user messages and find the index of the first user message + for i, item := range *h { + if item.Role == schema.User { + userImgCount++ + if firstUserImgIndex == -1 { + firstUserImgIndex = i + } + } + } + + // if there are already 4 user messages, delete the first one before adding the new message + if userImgCount >= 4 && firstUserImgIndex >= 0 { + // delete the first user message + *h = append( + (*h)[:firstUserImgIndex], + (*h)[firstUserImgIndex+1:]..., + ) + } + // add the new user message to the history + *h = append(*h, msg) + } + + // for assistant message: + // - keep at most the last 10 assistant messages + if msg.Role == schema.Assistant { + // add the new assistant message to the history + *h = append(*h, msg) + + // if there are more than 10 assistant messages, remove the oldest ones + assistantMsgCount := 0 + for i := len(*h) - 1; i >= 0; i-- { + if (*h)[i].Role == schema.Assistant { + assistantMsgCount++ + if assistantMsgCount > 10 { + *h = append((*h)[:i], (*h)[i+1:]...) + } + } + } + } +} + +func logRequest(messages ConversationHistory) { + msgs := make(ConversationHistory, 0, len(messages)) + for _, message := range messages { + msg := &schema.Message{ + Role: message.Role, + } + if message.Content != "" { + msg.Content = message.Content + } else if len(message.MultiContent) > 0 { + for _, mc := range message.MultiContent { + switch mc.Type { + case schema.ChatMessagePartTypeImageURL: + // Create a copy of the ImageURL to avoid modifying the original message + imageURLCopy := *mc.ImageURL + if strings.HasPrefix(imageURLCopy.URL, "data:image/") { + imageURLCopy.URL = "" + } + msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{ + Type: mc.Type, + ImageURL: &imageURLCopy, + }) + } + } + } + msgs = append(msgs, msg) + } + log.Debug().Interface("messages", msgs).Msg("log request messages") +} + +func logResponse(resp *schema.Message) { + logger := log.Info().Str("role", string(resp.Role)). 
+ Str("content", resp.Content) + if resp.ResponseMeta != nil { + logger = logger.Interface("response_meta", resp.ResponseMeta) + } + if resp.Extra != nil { + logger = logger.Interface("extra", resp.Extra) + } + logger.Msg("log response message") +} diff --git a/uixt/ai/testdata/llk_4.png b/uixt/ai/testdata/llk_4.png new file mode 100644 index 00000000..35a05aaf Binary files /dev/null and b/uixt/ai/testdata/llk_4.png differ
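
A minimal usage sketch of the assertion API refactored in this patch. It assumes the ARK_API_KEY and ARK_MODEL_ID environment variables are set for the UI-TARS backend, that Assert accepts a base64 data-URL screenshot, and that types.Size exposes Width/Height fields; the assertion text and dimensions below are hypothetical examples, not values from the patch.

package main

import (
	"context"
	"fmt"

	"github.com/httprunner/httprunner/v5/uixt/ai"
	"github.com/httprunner/httprunner/v5/uixt/types"
)

func main() {
	// hypothetical screenshot payload; in real usage this comes from a device capture
	screenshotBase64 := "data:image/jpeg;base64,..."

	// build a UI-TARS backed asserter; requires ARK_API_KEY and ARK_MODEL_ID in the environment
	asserter, err := ai.NewAsserter(context.Background(), ai.LLMServiceTypeUITARS)
	if err != nil {
		panic(err)
	}

	// run a single assertion against the screenshot
	resp, err := asserter.Assert(&ai.AssertOptions{
		Assertion:  "the login button is visible",
		Screenshot: screenshotBase64,
		Size:       types.Size{Width: 1080, Height: 2340}, // field names assumed from context
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Pass, resp.Thought)
}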