diff --git a/README.en.md b/README.en.md index b8b2c79f..6a72ae73 100644 --- a/README.en.md +++ b/README.en.md @@ -18,6 +18,9 @@ Compared to other UI automation frameworks, HttpRunner's main features include: - Unified API across multiple platforms, reducing learning and horizontal expansion costs - Embracing the open-source ecosystem, fully reusing open-source components +> [HttpRunner v5 用户指南(更新中)](https://debugtalk.feishu.cn/wiki/RqGuw17bsizGTik9WuNcGQyhnaf) +> [HttpRunner DeepWiki](https://deepwiki.com/httprunner/httprunner) + ## Usage ```text $ hrp -h diff --git a/README.md b/README.md index c8a1b4a6..b24d2306 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,8 @@ HttpRunner 相比其它 UI 自动化框架,主要特点包括: - 多端统一 API,降低学习和横向拓展的成本 - 拥抱开源生态,充分复用开源组件 -[HttpRunner v5 用户指南(更新中)](https://debugtalk.feishu.cn/wiki/RqGuw17bsizGTik9WuNcGQyhnaf) +> [HttpRunner v5 用户指南(更新中)](https://debugtalk.feishu.cn/wiki/RqGuw17bsizGTik9WuNcGQyhnaf) +> [HttpRunner DeepWiki](https://deepwiki.com/httprunner/httprunner) ## 使用说明 diff --git a/code/code.go b/code/code.go index a7205540..7f79326c 100644 --- a/code/code.go +++ b/code/code.go @@ -96,16 +96,15 @@ var ( LoopActionNotFoundError = errors.New("loop action not found error") // 79 ) -// AI related: [80, 90) +// CV related: [80, 90) var ( - CVEnvMissedError = errors.New("CV env missed error") // 80 - CVRequestError = errors.New("CV prepare request error") // 81 - CVServiceConnectionError = errors.New("CV service connect error") // 82 - CVResponseError = errors.New("CV parse response error") // 83 - CVResultNotFoundError = errors.New("CV result not found") // 84 + CVEnvMissedError = errors.New("CV env missed error") // 80 + CVPrepareRequestError = errors.New("CV prepare request error") // 81 + CVRequestServiceError = errors.New("CV request service error") // 82 + CVParseResponseError = errors.New("CV parse response error") // 83 + CVResultNotFoundError = errors.New("CV result not found") // 84 - LLMEnvMissedError = errors.New("LLM env missed error") // 85 - StateUnknowError = errors.New("detect state failed") // 89 + StateUnknowError = errors.New("detect state failed") // 85 ) // trackings related: [90, 100) @@ -121,6 +120,15 @@ var ( RiskControlAccountActivation = errors.New("risk control account activation") // 102 ) +// LLM related: [110, 120) +var ( + LLMEnvMissedError = errors.New("missed LLM env error") // 110 + LLMPrepareRequestError = errors.New("prepare LLM request error") // 111 + LLMRequestServiceError = errors.New("request LLM service error") // 112 + LLMParsePlanningResponseError = errors.New("parse LLM planning response error") // 113 + LLMParseAssertionResponseError = errors.New("parse LLM assertion response error") // 114 +) + var errorsMap = map[error]int{ // environment ConfigureError: 3, @@ -194,14 +202,21 @@ var errorsMap = map[error]int{ MobileUIPopupError: 78, LoopActionNotFoundError: 79, - // AI related - CVEnvMissedError: 80, - CVRequestError: 81, - CVServiceConnectionError: 82, - CVResponseError: 83, - CVResultNotFoundError: 84, - LLMEnvMissedError: 85, - StateUnknowError: 89, + // CV related + CVEnvMissedError: 80, + CVPrepareRequestError: 81, + CVRequestServiceError: 82, + CVParseResponseError: 83, + CVResultNotFoundError: 84, + + StateUnknowError: 85, + + // LLM related + LLMEnvMissedError: 110, + LLMPrepareRequestError: 111, + LLMRequestServiceError: 112, + LLMParsePlanningResponseError: 113, + LLMParseAssertionResponseError: 114, // trackings related TrackingGetError: 90, diff --git a/config.go b/config.go index 3eaf08a6..5c0506d3 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "reflect" "github.com/httprunner/httprunner/v5/internal/builtin" + "github.com/httprunner/httprunner/v5/uixt/ai" "github.com/httprunner/httprunner/v5/uixt/option" ) @@ -42,6 +43,8 @@ type TConfig struct { Path string `json:"path,omitempty" yaml:"path,omitempty"` // testcase file path PluginSetting *PluginConfig `json:"plugin,omitempty" yaml:"plugin,omitempty"` // plugin config IgnorePopup bool `json:"ignore_popup,omitempty" yaml:"ignore_popup,omitempty"` + LLMService ai.LLMServiceType `json:"llm_service,omitempty" yaml:"llm_service,omitempty"` + CVService ai.CVServiceType `json:"cv_service,omitempty" yaml:"cv_service,omitempty"` } func (c *TConfig) Get() *TConfig { @@ -108,6 +111,18 @@ func (c *TConfig) SetWeight(weight int) *TConfig { return c } +// SetLLMService sets LLM service for current testcase. +func (c *TConfig) SetLLMService(llmService ai.LLMServiceType) *TConfig { + c.LLMService = llmService + return c +} + +// SetCVService sets CV service for current testcase. +func (c *TConfig) SetCVService(cvService ai.CVServiceType) *TConfig { + c.CVService = cvService + return c +} + func (c *TConfig) SetWebSocket(times, interval, timeout, size int64) *TConfig { c.WebSocketSetting = &WebSocketConfig{ ReconnectionTimes: times, diff --git a/convert/from_curl.go b/convert/from_curl.go index 4ad04d90..8daa46ea 100644 --- a/convert/from_curl.go +++ b/convert/from_curl.go @@ -119,7 +119,7 @@ func LoadCurlCase(path string) (*hrp.TestCaseDef, error) { } func readFileLines(path string) ([]string, error) { - file, err := os.Open(path) + file, err := os.OpenFile(path, os.O_RDONLY, 0o600) if err != nil { log.Error().Err(err).Str("path", path).Msg("open file failed") return nil, err diff --git a/internal/version/VERSION b/internal/version/VERSION index ee7d2483..13a7f720 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2504271150 +v5.0.0-beta-2504292008 diff --git a/pkg/gadb/device.go b/pkg/gadb/device.go index 75bb376a..f4215a58 100644 --- a/pkg/gadb/device.go +++ b/pkg/gadb/device.go @@ -536,7 +536,7 @@ func (d *Device) List(remotePath string) (devFileInfos []DeviceFileInfo, err err } func (d *Device) PushFile(localPath, remotePath string, modification ...time.Time) (err error) { - localFile, err := os.Open(localPath) + localFile, err := os.OpenFile(localPath, os.O_RDONLY, 0o600) if err != nil { return err } @@ -645,7 +645,7 @@ func (d *Device) installViaABBExec(apk io.ReadSeeker, args ...string) (raw []byt } func (d *Device) InstallAPK(apkPath string, args ...string) (string, error) { - apkFile, err := os.Open(apkPath) + apkFile, err := os.OpenFile(apkPath, os.O_RDONLY, 0o600) if err != nil { return "", errors.Wrap(err, fmt.Sprintf("open apk file %s failed", apkPath)) } diff --git a/runner.go b/runner.go index 69d3bf18..607b0de4 100644 --- a/runner.go +++ b/runner.go @@ -418,6 +418,17 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) { } r.parametersIterator = parametersIterator + // ai options + aiOpts := []ai.AIServiceOption{} + if parsedConfig.LLMService != "" { + aiOpts = append(aiOpts, ai.WithLLMService(parsedConfig.LLMService)) + } + if parsedConfig.CVService == "" { + // default to vedem + parsedConfig.CVService = ai.CVServiceTypeVEDEM + } + aiOpts = append(aiOpts, ai.WithCVService(parsedConfig.CVService)) + // parse android devices config for _, androidDeviceOptions := range parsedConfig.Android { err := r.parseDeviceConfig(androidDeviceOptions, parsedConfig.Variables) @@ -435,7 +446,7 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) { return nil, errors.Wrap(err, "init android driver failed") } - driverExt := uixt.NewXTDriver(driver, ai.WithCVService(ai.CVServiceTypeVEDEM)) + driverExt := uixt.NewXTDriver(driver, aiOpts...) r.uixtDrivers[androidDeviceOptions.SerialNumber] = driverExt } // parse iOS devices config @@ -455,7 +466,7 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) { return nil, errors.Wrap(err, "init ios driver failed") } - driverExt := uixt.NewXTDriver(driver, ai.WithCVService(ai.CVServiceTypeVEDEM)) + driverExt := uixt.NewXTDriver(driver, aiOpts...) r.uixtDrivers[iosDeviceOptions.UDID] = driverExt } // parse harmony devices config @@ -475,7 +486,7 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) { return nil, errors.Wrap(err, "init harmony driver failed") } - driverExt := uixt.NewXTDriver(driver, ai.WithCVService(ai.CVServiceTypeVEDEM)) + driverExt := uixt.NewXTDriver(driver, aiOpts...) r.uixtDrivers[harmonyDeviceOptions.ConnectKey] = driverExt } diff --git a/summary.go b/summary.go index 0cac3117..ab6892a7 100644 --- a/summary.go +++ b/summary.go @@ -95,7 +95,7 @@ func (s *Summary) GenHTMLReport() error { } reportPath := filepath.Join(reportsDir, "report.html") - file, err := os.OpenFile(reportPath, os.O_WRONLY|os.O_CREATE, 0o666) + file, err := os.OpenFile(reportPath, os.O_WRONLY|os.O_CREATE, 0o600) if err != nil { log.Error().Err(err).Msg("open file failed") return err diff --git a/uixt/ai/ai.go b/uixt/ai/ai.go index c29ce677..3d039820 100644 --- a/uixt/ai/ai.go +++ b/uixt/ai/ai.go @@ -49,26 +49,60 @@ type LLMServiceType string const ( LLMServiceTypeUITARS LLMServiceType = "ui-tars" LLMServiceTypeGPT4o LLMServiceType = "gpt-4o" + LLMServiceTypeGPT4Vision LLMServiceType = "gpt-4-vision" + LLMServiceTypeQwenVL LLMServiceType = "qwen-vl" LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3" ) -func WithLLMService(service LLMServiceType) AIServiceOption { +// ILLMService 定义了 LLM 服务接口,包括规划和断言功能 +type ILLMService interface { + Call(opts *PlanningOptions) (*PlanningResult, error) + Assert(opts *AssertOptions) (*AssertionResponse, error) +} + +func WithLLMService(modelType LLMServiceType) AIServiceOption { return func(opts *AIServices) { - if service == LLMServiceTypeGPT4o { - var err error - opts.ILLMService, err = NewPlanner(context.Background()) - if err != nil { - log.Error().Err(err).Msg("init gpt-4o llm service failed") - os.Exit(code.GetErrorCode(err)) - } + // init planner + var planner IPlanner + var err error + switch modelType { + case LLMServiceTypeGPT4o: + // TODO: implement gpt-4o planner and asserter + planner, err = NewPlanner(context.Background()) + case LLMServiceTypeUITARS: + planner, err = NewUITarsPlanner(context.Background()) } - if service == LLMServiceTypeUITARS { - var err error - opts.ILLMService, err = NewUITarsPlanner(context.Background()) - if err != nil { - log.Error().Err(err).Msg("init ui-tars llm service failed") - os.Exit(code.GetErrorCode(err)) - } + if err != nil { + log.Error().Err(err).Msgf("init %s planner failed", modelType) + os.Exit(code.GetErrorCode(err)) + } + + // init asserter + asserter, err := NewAsserter(context.Background(), modelType) + if err != nil { + log.Error().Err(err).Msgf("init %s asserter failed", modelType) + os.Exit(code.GetErrorCode(err)) + } + + opts.ILLMService = &combinedLLMService{ + planner: planner, + asserter: asserter, } } } + +// combinedLLMService 实现了 ILLMService 接口,组合了规划和断言功能 +type combinedLLMService struct { + planner IPlanner // 提供规划功能 + asserter IAsserter // 提供断言功能 +} + +// Call 执行规划功能 +func (c *combinedLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) { + return c.planner.Call(opts) +} + +// Assert 执行断言功能 +func (c *combinedLLMService) Assert(opts *AssertOptions) (*AssertionResponse, error) { + return c.asserter.Assert(opts) +} diff --git a/uixt/ai/ai_ark.go b/uixt/ai/ai_ark.go new file mode 100644 index 00000000..265d4afb --- /dev/null +++ b/uixt/ai/ai_ark.go @@ -0,0 +1,62 @@ +package ai + +import ( + "os" + + "github.com/cloudwego/eino-ext/components/model/ark" + "github.com/httprunner/httprunner/v5/code" + "github.com/httprunner/httprunner/v5/internal/config" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" +) + +const ( + EnvArkBaseURL = "ARK_BASE_URL" + EnvArkAPIKey = "ARK_API_KEY" + EnvArkModelID = "ARK_MODEL_ID" +) + +func GetArkModelConfig() (*ark.ChatModelConfig, error) { + if err := config.LoadEnv(); err != nil { + return nil, errors.Wrap(code.LoadEnvError, err.Error()) + } + + arkBaseURL := os.Getenv(EnvArkBaseURL) + arkAPIKey := os.Getenv(EnvArkAPIKey) + if arkAPIKey == "" { + return nil, errors.Wrapf(code.LLMEnvMissedError, + "env %s missed", EnvArkAPIKey) + } + modelName := os.Getenv(EnvArkModelID) + if modelName == "" { + return nil, errors.Wrapf(code.LLMEnvMissedError, + "env %s missed", EnvArkModelID) + } + timeout := defaultTimeout + + // https://www.volcengine.com/docs/82379/1494384?redirect=1 + temperature := float32(0.01) // [0, 2] 采样温度。控制了生成文本时对每个候选词的概率分布进行平滑的程度。 + // topP := float32(0.7) // [0, 1] 核采样概率阈值。模型会考虑概率质量在 top_p 内的 token 结果。 + // maxTokens := int(4096) // 模型可以生成的最大 token 数量。输入 token 和输出 token 的总长度还受模型的上下文长度限制。 + // frequencyPenalty := float32(0) // [-2, 2] 频率惩罚系数。如果值为正,会根据新 token 在文本中的出现频率对其进行惩罚,从而降低模型逐字重复的可能性。 + + modelConfig := &ark.ChatModelConfig{ + BaseURL: arkBaseURL, + APIKey: arkAPIKey, + Model: modelName, + Timeout: &timeout, + Temperature: &temperature, + // TopP: &topP, + // MaxTokens: &maxTokens, + // FrequencyPenalty: &frequencyPenalty, + } + + // log config info + log.Info().Str("model", modelConfig.Model). + Str("baseURL", modelConfig.BaseURL). + Str("apiKey", maskAPIKey(modelConfig.APIKey)). + Str("timeout", defaultTimeout.String()). + Msg("get model config") + + return modelConfig, nil +} diff --git a/uixt/ai/ai_openai.go b/uixt/ai/ai_openai.go new file mode 100644 index 00000000..85d35ab4 --- /dev/null +++ b/uixt/ai/ai_openai.go @@ -0,0 +1,79 @@ +package ai + +import ( + "os" + + "github.com/cloudwego/eino-ext/components/model/openai" + openai2 "github.com/cloudwego/eino-ext/libs/acl/openai" + "github.com/getkin/kin-openapi/openapi3gen" + "github.com/httprunner/httprunner/v5/code" + "github.com/httprunner/httprunner/v5/internal/config" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" +) + +const ( + EnvOpenAIBaseURL = "OPENAI_BASE_URL" + EnvOpenAIAPIKey = "OPENAI_API_KEY" + EnvModelName = "LLM_MODEL_NAME" +) + +// GetOpenAIModelConfig get OpenAI config +func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) { + if err := config.LoadEnv(); err != nil { + return nil, errors.Wrap(code.LoadEnvError, err.Error()) + } + + openaiBaseURL := os.Getenv(EnvOpenAIBaseURL) + if openaiBaseURL == "" { + return nil, errors.Wrapf(code.LLMEnvMissedError, + "env %s missed", EnvOpenAIBaseURL) + } + openaiAPIKey := os.Getenv(EnvOpenAIAPIKey) + if openaiAPIKey == "" { + return nil, errors.Wrapf(code.LLMEnvMissedError, + "env %s missed", EnvOpenAIAPIKey) + } + modelName := os.Getenv(EnvModelName) + if modelName == "" { + return nil, errors.Wrapf(code.LLMEnvMissedError, + "env %s missed", EnvModelName) + } + + type OutputFormat struct { + Thought string `json:"thought"` + Action string `json:"action"` + Error string `json:"error,omitempty"` + } + outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil) + if err != nil { + return nil, err + } + + modelConfig := &openai.ChatModelConfig{ + BaseURL: openaiBaseURL, + APIKey: openaiAPIKey, + Model: modelName, + Timeout: defaultTimeout, + // set structured response format + // https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go + ResponseFormat: &openai2.ChatCompletionResponseFormat{ + Type: openai2.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{ + Name: "thought_and_action", + Description: "data that describes planning thought and action", + Schema: outputFormatSchema.Value, + Strict: false, + }, + }, + } + + // log config info + log.Info().Str("model", modelConfig.Model). + Str("baseURL", modelConfig.BaseURL). + Str("apiKey", maskAPIKey(modelConfig.APIKey)). + Str("timeout", defaultTimeout.String()). + Msg("get model config") + + return modelConfig, nil +} diff --git a/uixt/ai/asserter.go b/uixt/ai/asserter.go new file mode 100644 index 00000000..9e0d628b --- /dev/null +++ b/uixt/ai/asserter.go @@ -0,0 +1,218 @@ +package ai + +import ( + "context" + "fmt" + "regexp" + "strings" + "time" + + "github.com/cloudwego/eino-ext/components/model/ark" + "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "github.com/httprunner/httprunner/v5/code" + "github.com/httprunner/httprunner/v5/internal/json" + "github.com/httprunner/httprunner/v5/uixt/types" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" +) + +// IAsserter interface defines the contract for assertion operations +type IAsserter interface { + Assert(opts *AssertOptions) (*AssertionResponse, error) +} + +// AssertOptions represents the input options for assertion +type AssertOptions struct { + Assertion string `json:"assertion"` // The assertion text to verify + Screenshot string `json:"screenshot"` // Base64 encoded screenshot + Size types.Size `json:"size"` // Screen dimensions +} + +// AssertionResponse represents the response from an AI assertion +type AssertionResponse struct { + Pass bool `json:"pass"` + Thought string `json:"thought"` +} + +// Asserter handles assertion using different AI models +type Asserter struct { + ctx context.Context + model model.ToolCallingChatModel + systemPrompt string + history ConversationHistory + modelType LLMServiceType +} + +// NewAsserter creates a new Asserter instance +func NewAsserter(ctx context.Context, modelType LLMServiceType) (*Asserter, error) { + asserter := &Asserter{ + ctx: ctx, + modelType: modelType, + systemPrompt: getAssertionSystemPrompt(modelType), + } + + switch modelType { + case LLMServiceTypeUITARS: + config, err := GetArkModelConfig() + if err != nil { + return nil, err + } + asserter.model, err = ark.NewChatModel(ctx, config) + if err != nil { + return nil, err + } + case LLMServiceTypeGPT4Vision, LLMServiceTypeGPT4o: + config, err := GetOpenAIModelConfig() + if err != nil { + return nil, err + } + asserter.model, err = openai.NewChatModel(ctx, config) + if err != nil { + return nil, err + } + default: + return nil, errors.New("not supported model type for asserter") + } + + return asserter, nil +} + +// getAssertionSystemPrompt returns the appropriate system prompt for the given model type +func getAssertionSystemPrompt(modelType LLMServiceType) string { + if modelType == LLMServiceTypeUITARS { + return defaultAssertionPrompt + "\n\n" + uiTarsAssertionResponseFormat + } + return defaultAssertionPrompt + "\n\n" + defaultAssertionResponseJsonFormat +} + +// Assert performs the assertion check on the screenshot +func (a *Asserter) Assert(opts *AssertOptions) (*AssertionResponse, error) { + // Validate input parameters + if err := validateAssertionInput(opts); err != nil { + return nil, errors.Wrap(err, "validate assertion parameters failed") + } + + // Reset history for each new assertion + a.history = ConversationHistory{ + { + Role: schema.System, + Content: a.systemPrompt, + }, + } + + // Create user message with screenshot and assertion + userMsg := &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{ + URL: opts.Screenshot, + Detail: schema.ImageURLDetailAuto, + }, + }, + { + Type: schema.ChatMessagePartTypeText, + Text: fmt.Sprintf(` +Here is the assertion. Please tell whether it is truthy according to the screenshot. +===================================== +%s +===================================== + `, opts.Assertion), + }, + }, + } + + // Append user message to history + a.history.Append(userMsg) + + // Call model service, generate response + logRequest(a.history) + startTime := time.Now() + resp, err := a.model.Generate(a.ctx, a.history) + log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). + Str("model", string(a.modelType)).Msg("call model service for assertion") + if err != nil { + return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) + } + logResponse(resp) + + // Parse result + result, err := parseAssertionResult(resp.Content) + if err != nil { + return nil, errors.Wrap(code.LLMParseAssertionResponseError, err.Error()) + } + + // Append assistant message to history + a.history.Append(&schema.Message{ + Role: schema.Assistant, + Content: resp.Content, + }) + + return result, nil +} + +// validateAssertionInput validates the input parameters for assertion +func validateAssertionInput(opts *AssertOptions) error { + if opts.Assertion == "" { + return errors.Wrap(code.LLMPrepareRequestError, "assertion text is required") + } + if opts.Screenshot == "" { + return errors.Wrap(code.LLMPrepareRequestError, "screenshot is required") + } + return nil +} + +// parseAssertionResult parses the model response into AssertionResponse +func parseAssertionResult(content string) (*AssertionResponse, error) { + // Extract JSON content from response + jsonContent := extractJSON(content) + if jsonContent == "" { + return nil, errors.New("could not extract JSON from response") + } + + // Parse JSON response + var result AssertionResponse + if err := json.Unmarshal([]byte(jsonContent), &result); err != nil { + return nil, errors.Wrap(code.LLMParseAssertionResponseError, err.Error()) + } + + return &result, nil +} + +// extractJSON extracts JSON content from a string that might contain markdown or other formatting +func extractJSON(content string) string { + content = strings.TrimSpace(content) + + // If the content is already a valid JSON, return it + if strings.HasPrefix(content, "{") && strings.HasSuffix(content, "}") { + return content + } + + // Try to extract JSON from markdown code blocks + jsonRegex := regexp.MustCompile(`(?:json)?\s*({[\s\S]*?})\s*`) + matches := jsonRegex.FindStringSubmatch(content) + if len(matches) > 1 { + return strings.TrimSpace(matches[1]) + } + + // Try a more robust approach for JSON with Chinese characters + startIdx := strings.Index(content, "{") + if startIdx >= 0 { + depth := 1 + for i := startIdx + 1; i < len(content); i++ { + if content[i] == '{' { + depth++ + } else if content[i] == '}' { + depth-- + if depth == 0 { + return content[startIdx : i+1] + } + } + } + } + + return content +} diff --git a/uixt/ai/asserter_prompts.go b/uixt/ai/asserter_prompts.go new file mode 100644 index 00000000..9ceb092d --- /dev/null +++ b/uixt/ai/asserter_prompts.go @@ -0,0 +1,25 @@ +package ai + +// Default assertion system prompt +const defaultAssertionPrompt = `You are a senior testing engineer. User will give an assertion and a screenshot of a page. By carefully viewing the screenshot, please tell whether the assertion is truthy.` + +// Default assertion response format +const defaultAssertionResponseJsonFormat = `Return in the following JSON format: +{ + pass: boolean, // whether the assertion is truthy + thought: string | null, // string, if the result is falsy, give the reason why it is falsy. Otherwise, put null. +}` + +// UI-TARS assertion response format +const uiTarsAssertionResponseFormat = `## Output Json String Format +` + "```" + ` +"{ + "pass": <>, + "thought": "<>" +}" +` + "```" + ` + +## Rules **MUST** follow +- Make sure to return **only** the JSON, with **no additional** text or explanations. +- Use Chinese in ` + "`Thought`" + ` part. +- You **MUST** strictly follow up the **Output Json String Format**.` diff --git a/uixt/ai/asserter_test.go b/uixt/ai/asserter_test.go new file mode 100644 index 00000000..8786914b --- /dev/null +++ b/uixt/ai/asserter_test.go @@ -0,0 +1,103 @@ +package ai + +import ( + "testing" + + "github.com/httprunner/httprunner/v5/uixt/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createAIService(t *testing.T) *AIServices { + aiService := NewAIService(WithLLMService(LLMServiceTypeUITARS)) + require.NotNil(t, aiService) + require.NotNil(t, aiService.ILLMService) + return aiService +} + +// 测试有效断言 +func TestValidAssertions(t *testing.T) { + aiService := createAIService(t) + + testCases := []struct { + name string + assertion string + imagePath string + expectPass bool + }{ + { + name: "深度思考功能已开启", + assertion: "输入框下方的「深度思考」文字是蓝色的", + imagePath: "testdata/deepseek_think_on.png", + expectPass: true, + }, + { + name: "深度思考功能未开启", + assertion: "输入框下方的「深度思考」文字是灰色的", + imagePath: "testdata/deepseek_think_off.png", + expectPass: true, + }, + { + name: "联网搜索功能已开启", + assertion: "输入框下方的「联网搜索」文字是蓝色的", + imagePath: "testdata/deepseek_network_on.png", + expectPass: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + imageBase64, size, err := loadImage(tc.imagePath) + require.NoError(t, err) + + result, err := aiService.ILLMService.Assert(&AssertOptions{ + Assertion: tc.assertion, + Screenshot: imageBase64, + Size: size, + }) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, tc.expectPass, result.Pass) + assert.NotEmpty(t, result.Thought) + }) + } +} + +// 测试无效参数 +func TestInvalidParameters(t *testing.T) { + aiService := createAIService(t) + testCases := []struct { + name string + assertion string + screenshot string + size types.Size + expectedError string + }{ + { + name: "缺少截图", + assertion: "测试断言", + screenshot: "", + size: types.Size{}, + expectedError: "screenshot is required", + }, + { + name: "缺少断言", + assertion: "", + screenshot: "some-base64-data", + size: types.Size{}, + expectedError: "assertion text is required", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := aiService.ILLMService.Assert(&AssertOptions{ + Assertion: tc.assertion, + Screenshot: tc.screenshot, + Size: tc.size, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedError) + }) + } +} diff --git a/uixt/ai/cv_vedem.go b/uixt/ai/cv_vedem.go index baf5bb21..532ea301 100644 --- a/uixt/ai/cv_vedem.go +++ b/uixt/ai/cv_vedem.go @@ -52,7 +52,7 @@ func (s *vedemCVService) ReadFromPath(imagePath string, opts ...option.ActionOpt imageResult *CVResult, err error) { imageBuf, err := os.ReadFile(imagePath) if err != nil { - err = errors.Wrap(code.CVRequestError, + err = errors.Wrap(code.CVPrepareRequestError, fmt.Sprintf("read image file error: %v", err)) return } @@ -116,21 +116,21 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A formWriter, err := bodyWriter.CreateFormFile("image", "screenshot.png") if err != nil { - err = errors.Wrap(code.CVRequestError, + err = errors.Wrap(code.CVPrepareRequestError, fmt.Sprintf("create form file error: %v", err)) return } size, err := formWriter.Write(imageBuf.Bytes()) if err != nil { - err = errors.Wrap(code.CVRequestError, + err = errors.Wrap(code.CVPrepareRequestError, fmt.Sprintf("write form error: %v", err)) return } err = bodyWriter.Close() if err != nil { - err = errors.Wrap(code.CVRequestError, + err = errors.Wrap(code.CVPrepareRequestError, fmt.Sprintf("close body writer error: %v", err)) return } @@ -146,7 +146,7 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A req, err = http.NewRequest("POST", os.Getenv("VEDEM_IMAGE_URL"), copiedBodyBuf) if err != nil { - err = errors.Wrap(code.CVRequestError, + err = errors.Wrap(code.CVPrepareRequestError, fmt.Sprintf("construct request error: %v", err)) return } @@ -192,7 +192,7 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A break } if resp == nil { - err = code.CVServiceConnectionError + err = code.CVRequestServiceError return } @@ -200,13 +200,13 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A results, err := io.ReadAll(resp.Body) if err != nil { - err = errors.Wrap(code.CVResponseError, + err = errors.Wrap(code.CVParseResponseError, fmt.Sprintf("read response body error: %v", err)) return } if resp.StatusCode != http.StatusOK { - err = errors.Wrap(code.CVResponseError, + err = errors.Wrap(code.CVParseResponseError, fmt.Sprintf("unexpected response status code: %d, results: %v", resp.StatusCode, string(results))) return @@ -215,13 +215,13 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A var imageResponse APIResponseImage err = json.Unmarshal(results, &imageResponse) if err != nil { - err = errors.Wrap(code.CVResponseError, + err = errors.Wrap(code.CVParseResponseError, fmt.Sprintf("json unmarshal veDEM image response body error, response=%s", string(results))) return } if imageResponse.Code != 0 { - err = errors.Wrap(code.CVResponseError, + err = errors.Wrap(code.CVParseResponseError, fmt.Sprintf("unexpected response data code: %d, message: %s", imageResponse.Code, imageResponse.Message)) return diff --git a/uixt/ai/planner.go b/uixt/ai/planner.go index b807e633..39435570 100644 --- a/uixt/ai/planner.go +++ b/uixt/ai/planner.go @@ -13,11 +13,12 @@ import ( "time" "github.com/cloudwego/eino/schema" + "github.com/httprunner/httprunner/v5/code" "github.com/httprunner/httprunner/v5/uixt/types" - "github.com/rs/zerolog/log" + "github.com/pkg/errors" ) -type ILLMService interface { +type IPlanner interface { Call(opts *PlanningOptions) (*PlanningResult, error) } @@ -57,23 +58,16 @@ const ( ) const ( - defaultTimeout = 60 * time.Second + defaultTimeout = 30 * time.Second ) -// Error types -var ( - ErrEmptyInstruction = fmt.Errorf("user instruction is empty") - ErrNoConversationHistory = fmt.Errorf("conversation history is empty") - ErrInvalidImageData = fmt.Errorf("invalid image data") -) - -func validateInput(opts *PlanningOptions) error { +func validatePlanningInput(opts *PlanningOptions) error { if opts.UserInstruction == "" { - return ErrEmptyInstruction + return errors.Wrap(code.LLMPrepareRequestError, "user instruction is empty") } - if opts.Message == nil { - return ErrNoConversationHistory + if opts.Message == nil || opts.Message.Role == "" { + return errors.Wrap(code.LLMPrepareRequestError, "user message is empty") } if opts.Message.Role == schema.User { @@ -81,7 +75,7 @@ func validateInput(opts *PlanningOptions) error { if len(opts.Message.MultiContent) > 0 { for _, content := range opts.Message.MultiContent { if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil { - return ErrInvalidImageData + return errors.Wrap(code.LLMPrepareRequestError, "invalid image data") } } } @@ -90,98 +84,6 @@ func validateInput(opts *PlanningOptions) error { return nil } -func logRequest(messages []*schema.Message) { - msgs := make([]*schema.Message, 0, len(messages)) - for _, message := range messages { - msg := &schema.Message{ - Role: message.Role, - } - if message.Content != "" { - msg.Content = message.Content - } else if len(message.MultiContent) > 0 { - for _, mc := range message.MultiContent { - switch mc.Type { - case schema.ChatMessagePartTypeImageURL: - // Create a copy of the ImageURL to avoid modifying the original message - imageURLCopy := *mc.ImageURL - if strings.HasPrefix(imageURLCopy.URL, "data:image/") { - imageURLCopy.URL = "" - } - msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{ - Type: mc.Type, - ImageURL: &imageURLCopy, - }) - } - } - } - msgs = append(msgs, msg) - } - log.Debug().Interface("messages", msgs).Msg("log request messages") -} - -func logResponse(resp *schema.Message) { - logger := log.Info().Str("role", string(resp.Role)). - Str("content", resp.Content) - if resp.ResponseMeta != nil { - logger = logger.Interface("response_meta", resp.ResponseMeta) - } - if resp.Extra != nil { - logger = logger.Interface("extra", resp.Extra) - } - logger.Msg("log response message") -} - -// appendConversationHistory adds a message to the conversation history -func appendConversationHistory(history []*schema.Message, msg *schema.Message) { - // for user image message: - // - keep at most 4 user image messages - // - delete the oldest user image message when the limit is reached - if msg.Role == schema.User { - // get all existing user messages - userImgCount := 0 - firstUserImgIndex := -1 - - // calculate the number of user messages and find the index of the first user message - for i, item := range history { - if item.Role == schema.User { - userImgCount++ - if firstUserImgIndex == -1 { - firstUserImgIndex = i - } - } - } - - // if there are already 4 user messages, delete the first one before adding the new message - if userImgCount >= 4 && firstUserImgIndex >= 0 { - // delete the first user message - history = append( - history[:firstUserImgIndex], - history[firstUserImgIndex+1:]..., - ) - } - // add the new user message to the history - history = append(history, msg) - } - - // for assistant message: - // - keep at most the last 10 assistant messages - if msg.Role == schema.Assistant { - // add the new assistant message to the history - history = append(history, msg) - - // if there are more than 10 assistant messages, remove the oldest ones - assistantMsgCount := 0 - for i := len(history) - 1; i >= 0; i-- { - if history[i].Role == schema.Assistant { - assistantMsgCount++ - if assistantMsgCount > 10 { - history = append(history[:i], history[i+1:]...) - } - } - } - } -} - // SavePositionImg saves an image with position markers func SavePositionImg(params struct { InputImgBase64 string @@ -249,37 +151,6 @@ func SavePositionImg(params struct { return nil } -// loadImage loads image and returns base64 encoded string -func loadImage(imagePath string) (base64Str string, size types.Size, err error) { - // Read the image file - imageFile, err := os.Open(imagePath) - if err != nil { - return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err) - } - defer imageFile.Close() - - // Decode the image to get its resolution - imageData, format, err := image.Decode(imageFile) - if err != nil { - return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err) - } - - // Get the resolution of the image - width := imageData.Bounds().Dx() - height := imageData.Bounds().Dy() - size = types.Size{Width: width, Height: height} - - // Convert image to base64 - buf := new(bytes.Buffer) - if err := png.Encode(buf, imageData); err != nil { - return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err) - } - base64Str = fmt.Sprintf("data:image/%s;base64,%s", format, - base64.StdEncoding.EncodeToString(buf.Bytes())) - - return base64Str, size, nil -} - // maskAPIKey masks the API key func maskAPIKey(key string) string { if len(key) <= 8 { diff --git a/uixt/ai/planner_gpt.go b/uixt/ai/planner_gpt.go index 4f6dcf5f..d4527201 100644 --- a/uixt/ai/planner_gpt.go +++ b/uixt/ai/planner_gpt.go @@ -4,90 +4,20 @@ import ( "context" "fmt" _ "image/jpeg" - "os" "strings" "time" "github.com/cloudwego/eino-ext/components/model/openai" - openai2 "github.com/cloudwego/eino-ext/libs/acl/openai" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" - "github.com/getkin/kin-openapi/openapi3gen" "github.com/pkg/errors" "github.com/rs/zerolog/log" "github.com/httprunner/httprunner/v5/code" - "github.com/httprunner/httprunner/v5/internal/config" "github.com/httprunner/httprunner/v5/internal/json" "github.com/httprunner/httprunner/v5/uixt/types" ) -const ( - EnvOpenAIBaseURL = "OPENAI_BASE_URL" - EnvOpenAIAPIKey = "OPENAI_API_KEY" - EnvModelName = "LLM_MODEL_NAME" -) - -// GetOpenAIModelConfig get OpenAI config -func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) { - if err := config.LoadEnv(); err != nil { - return nil, errors.Wrap(code.LoadEnvError, err.Error()) - } - - openaiBaseURL := os.Getenv(EnvOpenAIBaseURL) - if openaiBaseURL == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvOpenAIBaseURL) - } - openaiAPIKey := os.Getenv(EnvOpenAIAPIKey) - if openaiAPIKey == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvOpenAIAPIKey) - } - modelName := os.Getenv(EnvModelName) - if modelName == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvModelName) - } - - type OutputFormat struct { - Thought string `json:"thought"` - Action string `json:"action"` - Error string `json:"error,omitempty"` - } - outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil) - if err != nil { - return nil, err - } - - modelConfig := &openai.ChatModelConfig{ - BaseURL: openaiBaseURL, - APIKey: openaiAPIKey, - Model: modelName, - Timeout: defaultTimeout, - // set structured response format - // https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go - ResponseFormat: &openai2.ChatCompletionResponseFormat{ - Type: openai2.ChatCompletionResponseFormatTypeJSONSchema, - JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{ - Name: "thought_and_action", - Description: "data that describes planning thought and action", - Schema: outputFormatSchema.Value, - Strict: false, - }, - }, - } - - // log config info - log.Info().Str("model", modelConfig.Model). - Str("baseURL", modelConfig.BaseURL). - Str("apiKey", maskAPIKey(modelConfig.APIKey)). - Str("timeout", defaultTimeout.String()). - Msg("get model config") - - return modelConfig, nil -} - func NewPlanner(ctx context.Context) (*Planner, error) { config, err := GetOpenAIModelConfig() if err != nil { @@ -99,8 +29,8 @@ func NewPlanner(ctx context.Context) (*Planner, error) { } return &Planner{ ctx: ctx, - config: config, model: model, + modelType: LLMServiceTypeGPT4o, systemPrompt: uiTarsPlanningPrompt, // TODO: change prompt with function calling }, nil } @@ -108,23 +38,23 @@ func NewPlanner(ctx context.Context) (*Planner, error) { type Planner struct { ctx context.Context model model.ToolCallingChatModel - config *openai.ChatModelConfig systemPrompt string - history []*schema.Message // conversation history + modelType LLMServiceType + history ConversationHistory } // Call performs UI planning using Vision Language Model func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { // validate input parameters - if err := validateInput(opts); err != nil { - return nil, errors.Wrap(err, "validate input parameters failed") + if err := validatePlanningInput(opts); err != nil { + return nil, errors.Wrap(err, "validate planning parameters failed") } // prepare prompt if len(p.history) == 0 { // add system message systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction - p.history = []*schema.Message{ + p.history = ConversationHistory{ { Role: schema.System, Content: systemPrompt, @@ -132,27 +62,27 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { } } // append user image message - appendConversationHistory(p.history, opts.Message) + p.history.Append(opts.Message) // call model service, generate response logRequest(p.history) startTime := time.Now() resp, err := p.model.Generate(p.ctx, p.history) log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", p.config.Model).Msg("call model service") + Str("model", string(p.modelType)).Msg("call model service") if err != nil { - return nil, fmt.Errorf("request model service failed: %w", err) + return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) } logResponse(resp) // parse result result, err := p.parseResult(resp, opts.Size) if err != nil { - return nil, errors.Wrap(err, "parse result failed") + return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error()) } // append assistant message - appendConversationHistory(p.history, &schema.Message{ + p.history.Append(&schema.Message{ Role: schema.Assistant, Content: result.ActionSummary, }) diff --git a/uixt/ai/planner_prompts.go b/uixt/ai/planner_prompts.go new file mode 100644 index 00000000..d303c87f --- /dev/null +++ b/uixt/ai/planner_prompts.go @@ -0,0 +1,29 @@ +package ai + +// https://www.volcengine.com/docs/82379/1536429 +const uiTarsPlanningPrompt = ` +You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +` + "```" + ` +Thought: ... +Action: ... +` + "```" + ` + +## Action Space +click(start_box='[x1, y1, x2, y2]') +left_double(start_box='[x1, y1, x2, y2]') +right_single(start_box='[x1, y1, x2, y2]') +drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]') +hotkey(key='') +type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `. +scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. + +## Note +- Use Chinese in ` + "`Thought`" + ` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part. + +## User Instruction +` diff --git a/uixt/ai/planner_test.go b/uixt/ai/planner_test.go index 735bee64..68a988b3 100644 --- a/uixt/ai/planner_test.go +++ b/uixt/ai/planner_test.go @@ -1,11 +1,18 @@ package ai import ( + "bytes" "context" + "encoding/base64" + "fmt" + "image" + "image/jpeg" + "image/png" "os" "testing" "github.com/cloudwego/eino/schema" + "github.com/httprunner/httprunner/v5/code" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -195,6 +202,53 @@ func TestChatList(t *testing.T) { require.NotNil(t, result) } +func TestHandleSwitch(t *testing.T) { + userInstruction := "发送框下方的联网搜索开关是开启状态" // 点击开启联网搜索开关 + // 检查发送框下方的联网搜索开关,蓝色为开启状态,灰色为关闭状态;若开关处于关闭状态,则点击进行开启 + + planner, err := NewUITarsPlanner(context.Background()) + require.NoError(t, err) + + testCases := []struct { + imageFile string + actionType ActionType + }{ + {"testdata/deepseek_think_off.png", ActionTypeClick}, + {"testdata/deepseek_think_on.png", ActionTypeFinished}, + {"testdata/deepseek_network_on.png", ActionTypeFinished}, + } + + for _, tc := range testCases { + imageBase64, size, err := loadImage(tc.imageFile) + require.NoError(t, err) + + opts := &PlanningOptions{ + UserInstruction: userInstruction, + Message: &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{ + URL: imageBase64, + }, + }, + }, + }, + Size: size, + } + + // Execute planning + result, err := planner.Call(opts) + + // Validate results + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, result.NextActions[0].ActionType, tc.actionType, + "Unexpected action type for image file: %s", tc.imageFile) + } +} + func TestValidateInput(t *testing.T) { imageBase64, size, err := loadImage("testdata/popup_risk_warning.png") require.NoError(t, err) @@ -212,7 +266,7 @@ func TestValidateInput(t *testing.T) { Role: schema.User, MultiContent: []schema.ChatMessagePart{ { - Type: "image_url", + Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{ URL: imageBase64, }, @@ -228,41 +282,46 @@ func TestValidateInput(t *testing.T) { opts: &PlanningOptions{ UserInstruction: "", Message: &schema.Message{ - Role: schema.User, - Content: "", + Role: schema.User, + MultiContent: []schema.ChatMessagePart{}, }, Size: size, }, - wantErr: ErrEmptyInstruction, + wantErr: code.LLMPrepareRequestError, }, { name: "empty conversation history", opts: &PlanningOptions{ UserInstruction: "点击立即卸载按钮", Message: &schema.Message{}, + Size: size, }, - wantErr: ErrNoConversationHistory, + wantErr: code.LLMPrepareRequestError, }, { name: "invalid image data", opts: &PlanningOptions{ UserInstruction: "点击继续使用按钮", Message: &schema.Message{ - Role: schema.User, - Content: "no image", + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Text: "no image", + }, + }, }, Size: size, }, - wantErr: ErrInvalidImageData, + wantErr: code.LLMPrepareRequestError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateInput(tt.opts) + err := validatePlanningInput(tt.opts) if tt.wantErr != nil { assert.Error(t, err) - assert.Equal(t, tt.wantErr, err) } else { assert.NoError(t, err) } @@ -361,3 +420,42 @@ func TestLoadImage(t *testing.T) { assert.Greater(t, jpegSize.Width, 0) assert.Greater(t, jpegSize.Height, 0) } + +// loadImage loads image and returns base64 encoded string +func loadImage(imagePath string) (base64Str string, size types.Size, err error) { + // Read the image file + imageFile, err := os.OpenFile(imagePath, os.O_RDONLY, 0o600) + if err != nil { + return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err) + } + defer imageFile.Close() + + // Decode the image to get its resolution + imageData, format, err := image.Decode(imageFile) + if err != nil { + return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err) + } + + // Get the resolution of the image + width := imageData.Bounds().Dx() + height := imageData.Bounds().Dy() + size = types.Size{Width: width, Height: height} + + // Convert image to base64 + buf := new(bytes.Buffer) + // 根据图像格式选择正确的编码器 + if format == "jpeg" || format == "jpg" { + if err := jpeg.Encode(buf, imageData, nil); err != nil { + return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err) + } + } else { + // 默认使用 PNG 编码 + if err := png.Encode(buf, imageData); err != nil { + return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err) + } + } + base64Str = fmt.Sprintf("data:image/%s;base64,%s", format, + base64.StdEncoding.EncodeToString(buf.Bytes())) + + return base64Str, size, nil +} diff --git a/uixt/ai/planner_ui_tars.go b/uixt/ai/planner_ui_tars.go index 57dc16af..83184233 100644 --- a/uixt/ai/planner_ui_tars.go +++ b/uixt/ai/planner_ui_tars.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "math" - "os" "regexp" "strconv" "strings" @@ -14,55 +13,12 @@ import ( "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" "github.com/httprunner/httprunner/v5/code" - "github.com/httprunner/httprunner/v5/internal/config" "github.com/httprunner/httprunner/v5/internal/json" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/pkg/errors" "github.com/rs/zerolog/log" ) -const ( - EnvArkBaseURL = "ARK_BASE_URL" - EnvArkAPIKey = "ARK_API_KEY" - EnvArkModelID = "ARK_MODEL_ID" -) - -func GetArkModelConfig() (*ark.ChatModelConfig, error) { - if err := config.LoadEnv(); err != nil { - return nil, errors.Wrap(code.LoadEnvError, err.Error()) - } - - arkBaseURL := os.Getenv(EnvArkBaseURL) - arkAPIKey := os.Getenv(EnvArkAPIKey) - if arkAPIKey == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvArkAPIKey) - } - modelName := os.Getenv(EnvArkModelID) - if modelName == "" { - return nil, errors.Wrapf(code.LLMEnvMissedError, - "env %s missed", EnvArkModelID) - } - timeout := defaultTimeout - temp := float32(0.7) - modelConfig := &ark.ChatModelConfig{ - BaseURL: arkBaseURL, - APIKey: arkAPIKey, - Model: modelName, - Temperature: &temp, - Timeout: &timeout, - } - - // log config info - log.Info().Str("model", modelConfig.Model). - Str("baseURL", modelConfig.BaseURL). - Str("apiKey", maskAPIKey(modelConfig.APIKey)). - Str("timeout", defaultTimeout.String()). - Msg("get model config") - - return modelConfig, nil -} - func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) { config, err := GetArkModelConfig() if err != nil { @@ -75,60 +31,32 @@ func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) { return &UITarsPlanner{ ctx: ctx, - config: config, model: chatModel, + modelType: LLMServiceTypeUITARS, systemPrompt: uiTarsPlanningPrompt, }, nil } -// https://www.volcengine.com/docs/82379/1536429 -const uiTarsPlanningPrompt = ` -You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. - -## Output Format -` + "```" + ` -Thought: ... -Action: ... -` + "```" + ` - -## Action Space -click(start_box='[x1, y1, x2, y2]') -left_double(start_box='[x1, y1, x2, y2]') -right_single(start_box='[x1, y1, x2, y2]') -drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]') -hotkey(key='') -type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `. -scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left') -wait() #Sleep for 5s and take a screenshot to check for any changes. -finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. - -## Note -- Use Chinese in ` + "`Thought`" + ` part. -- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part. - -## User Instruction -` - type UITarsPlanner struct { ctx context.Context model model.ToolCallingChatModel - config *ark.ChatModelConfig systemPrompt string - history []*schema.Message // conversation history + modelType LLMServiceType + history ConversationHistory } // Call performs UI planning using Vision Language Model func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) { // validate input parameters - if err := validateInput(opts); err != nil { - return nil, errors.Wrap(err, "validate input parameters failed") + if err := validatePlanningInput(opts); err != nil { + return nil, errors.Wrap(err, "validate planning parameters failed") } // prepare prompt if len(p.history) == 0 { // add system message systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction - p.history = []*schema.Message{ + p.history = ConversationHistory{ { Role: schema.System, Content: systemPrompt, @@ -136,27 +64,27 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) { } } // append user image message - appendConversationHistory(p.history, opts.Message) + p.history.Append(opts.Message) // call model service, generate response logRequest(p.history) startTime := time.Now() resp, err := p.model.Generate(p.ctx, p.history) log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", p.config.Model).Msg("call model service") + Str("model", string(p.modelType)).Msg("call model service") if err != nil { - return nil, fmt.Errorf("request model service failed: %w", err) + return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) } logResponse(resp) // parse result result, err := p.parseResult(resp, opts.Size) if err != nil { - return nil, errors.Wrap(err, "parse result failed") + return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error()) } // append assistant message - appendConversationHistory(p.history, &schema.Message{ + p.history.Append(&schema.Message{ Role: schema.Assistant, Content: result.ActionSummary, }) diff --git a/uixt/ai/session.go b/uixt/ai/session.go new file mode 100644 index 00000000..659ccc88 --- /dev/null +++ b/uixt/ai/session.go @@ -0,0 +1,103 @@ +package ai + +import ( + "strings" + + "github.com/cloudwego/eino/schema" + "github.com/rs/zerolog/log" +) + +// ConversationHistory represents a sequence of chat messages +type ConversationHistory []*schema.Message + +// Append adds a new message to the conversation history +func (h *ConversationHistory) Append(msg *schema.Message) { + // for user image message: + // - keep at most 4 user image messages + // - delete the oldest user image message when the limit is reached + if msg.Role == schema.User { + // get all existing user messages + userImgCount := 0 + firstUserImgIndex := -1 + + // calculate the number of user messages and find the index of the first user message + for i, item := range *h { + if item.Role == schema.User { + userImgCount++ + if firstUserImgIndex == -1 { + firstUserImgIndex = i + } + } + } + + // if there are already 4 user messages, delete the first one before adding the new message + if userImgCount >= 4 && firstUserImgIndex >= 0 { + // delete the first user message + *h = append( + (*h)[:firstUserImgIndex], + (*h)[firstUserImgIndex+1:]..., + ) + } + // add the new user message to the history + *h = append(*h, msg) + } + + // for assistant message: + // - keep at most the last 10 assistant messages + if msg.Role == schema.Assistant { + // add the new assistant message to the history + *h = append(*h, msg) + + // if there are more than 10 assistant messages, remove the oldest ones + assistantMsgCount := 0 + for i := len(*h) - 1; i >= 0; i-- { + if (*h)[i].Role == schema.Assistant { + assistantMsgCount++ + if assistantMsgCount > 10 { + *h = append((*h)[:i], (*h)[i+1:]...) + } + } + } + } +} + +func logRequest(messages ConversationHistory) { + msgs := make(ConversationHistory, 0, len(messages)) + for _, message := range messages { + msg := &schema.Message{ + Role: message.Role, + } + if message.Content != "" { + msg.Content = message.Content + } else if len(message.MultiContent) > 0 { + for _, mc := range message.MultiContent { + switch mc.Type { + case schema.ChatMessagePartTypeImageURL: + // Create a copy of the ImageURL to avoid modifying the original message + imageURLCopy := *mc.ImageURL + if strings.HasPrefix(imageURLCopy.URL, "data:image/") { + imageURLCopy.URL = "" + } + msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{ + Type: mc.Type, + ImageURL: &imageURLCopy, + }) + } + } + } + msgs = append(msgs, msg) + } + log.Debug().Interface("messages", msgs).Msg("log request messages") +} + +func logResponse(resp *schema.Message) { + logger := log.Info().Str("role", string(resp.Role)). + Str("content", resp.Content) + if resp.ResponseMeta != nil { + logger = logger.Interface("response_meta", resp.ResponseMeta) + } + if resp.Extra != nil { + logger = logger.Interface("extra", resp.Extra) + } + logger.Msg("log response message") +} diff --git a/uixt/ai/testdata/deepseek_network_on.png b/uixt/ai/testdata/deepseek_network_on.png new file mode 100644 index 00000000..384143dd Binary files /dev/null and b/uixt/ai/testdata/deepseek_network_on.png differ diff --git a/uixt/ai/testdata/deepseek_think_off.png b/uixt/ai/testdata/deepseek_think_off.png new file mode 100644 index 00000000..a87086eb Binary files /dev/null and b/uixt/ai/testdata/deepseek_think_off.png differ diff --git a/uixt/ai/testdata/deepseek_think_on.png b/uixt/ai/testdata/deepseek_think_on.png new file mode 100644 index 00000000..ea1f6745 Binary files /dev/null and b/uixt/ai/testdata/deepseek_think_on.png differ diff --git a/uixt/ai/testdata/llk_4.png b/uixt/ai/testdata/llk_4.png new file mode 100644 index 00000000..35a05aaf Binary files /dev/null and b/uixt/ai/testdata/llk_4.png differ diff --git a/uixt/android_driver_adb.go b/uixt/android_driver_adb.go index 7e617bde..34cab417 100644 --- a/uixt/android_driver_adb.go +++ b/uixt/android_driver_adb.go @@ -701,7 +701,7 @@ func (ad *ADBDriver) StopCaptureLog() (result interface{}, err error) { return pointRes, nil } - reader, err := os.Open(files[0]) + reader, err := os.OpenFile(files[0], os.O_RDONLY, 0o600) if err != nil { log.Info().Msg("open File error") return pointRes, nil diff --git a/uixt/driver_ext_ai.go b/uixt/driver_ext_ai.go index 17df0dba..286d90ee 100644 --- a/uixt/driver_ext_ai.go +++ b/uixt/driver_ext_ai.go @@ -55,14 +55,6 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error { return nil } -func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (string, error) { - return "", nil -} - -func (dExt *XTDriver) AIAssert(text string, opts ...option.ActionOption) error { - return nil -} - func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) { if dExt.LLMService == nil { return nil, errors.New("LLM service is not initialized") @@ -116,3 +108,45 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) ( } return result, nil } + +func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (string, error) { + return "", nil +} + +func (dExt *XTDriver) AIAssert(assertion string, opts ...option.ActionOption) error { + if dExt.LLMService == nil { + return errors.New("LLM service is not initialized") + } + + compressedBufSource, err := dExt.GetScreenShotBuffer() + if err != nil { + return err + } + + // convert buffer to base64 string + screenShotBase64 := "data:image/jpeg;base64," + + base64.StdEncoding.EncodeToString(compressedBufSource.Bytes()) + + // get window size + size, err := dExt.IDriver.WindowSize() + if err != nil { + return errors.Wrap(err, "get window size for AI assertion failed") + } + + // execute assertion + assertOpts := &ai.AssertOptions{ + Assertion: assertion, + Screenshot: screenShotBase64, + Size: size, + } + result, err := dExt.LLMService.Assert(assertOpts) + if err != nil { + return errors.Wrap(err, "AI assertion failed") + } + + if !result.Pass { + return errors.New(result.Thought) + } + + return nil +} diff --git a/uixt/driver_ext_test.go b/uixt/driver_ext_test.go index 8113d5ef..1ba1646c 100644 --- a/uixt/driver_ext_test.go +++ b/uixt/driver_ext_test.go @@ -127,6 +127,9 @@ func TestDriverExt_TapByLLM(t *testing.T) { driver := setupDriverExt(t) err := driver.AIAction("点击第一个帖子的作者头像") assert.Nil(t, err) + + err = driver.AIAssert("当前在个人介绍页") + assert.Nil(t, err) } func TestDriverExt_StartToGoal(t *testing.T) { diff --git a/uixt/driver_utils.go b/uixt/driver_utils.go index ec57ea21..1269cd55 100644 --- a/uixt/driver_utils.go +++ b/uixt/driver_utils.go @@ -185,8 +185,8 @@ func (dExt *XTDriver) DoValidation(check, assert, expected string, message ...st switch check { case SelectorOCR: err = dExt.assertOCR(expected, assert) - // case SelectorAI: - // // TODO + case SelectorAI: + err = dExt.AIAssert(assert) case SelectorForegroundApp: err = dExt.assertForegroundApp(expected, assert) default: diff --git a/uixt/ios_driver_wda.go b/uixt/ios_driver_wda.go index 7d09948c..e3a96c61 100644 --- a/uixt/ios_driver_wda.go +++ b/uixt/ios_driver_wda.go @@ -970,7 +970,7 @@ func (wd *WDADriver) StartCaptureLog(identifier ...string) error { func (wd *WDADriver) PushImage(localPath string) error { log.Info().Str("localPath", localPath).Msg("WDADriver.PushImage") - localFile, err := os.Open(localPath) + localFile, err := os.OpenFile(localPath, os.O_RDONLY, 0o600) if err != nil { return err }