Merge branch 'feat-ai-assert' into 'master'

新增 AIAssert 功能

See merge request iesqa/httprunner!81
This commit is contained in:
李隆
2025-04-29 12:15:27 +00:00
31 changed files with 936 additions and 374 deletions

View File

@@ -18,6 +18,9 @@ Compared to other UI automation frameworks, HttpRunner's main features include:
- Unified API across multiple platforms, reducing learning and horizontal expansion costs
- Embracing the open-source ecosystem, fully reusing open-source components
> [HttpRunner v5 用户指南(更新中)](https://debugtalk.feishu.cn/wiki/RqGuw17bsizGTik9WuNcGQyhnaf)
> [HttpRunner DeepWiki](https://deepwiki.com/httprunner/httprunner)
## Usage
```text
$ hrp -h

View File

@@ -18,7 +18,8 @@ HttpRunner 相比其它 UI 自动化框架,主要特点包括:
- 多端统一 API降低学习和横向拓展的成本
- 拥抱开源生态,充分复用开源组件
[HttpRunner v5 用户指南(更新中)](https://debugtalk.feishu.cn/wiki/RqGuw17bsizGTik9WuNcGQyhnaf)
> [HttpRunner v5 用户指南(更新中)](https://debugtalk.feishu.cn/wiki/RqGuw17bsizGTik9WuNcGQyhnaf)
> [HttpRunner DeepWiki](https://deepwiki.com/httprunner/httprunner)
## 使用说明

View File

@@ -96,16 +96,15 @@ var (
LoopActionNotFoundError = errors.New("loop action not found error") // 79
)
// AI related: [80, 90)
// CV related: [80, 90)
var (
CVEnvMissedError = errors.New("CV env missed error") // 80
CVRequestError = errors.New("CV prepare request error") // 81
CVServiceConnectionError = errors.New("CV service connect error") // 82
CVResponseError = errors.New("CV parse response error") // 83
CVResultNotFoundError = errors.New("CV result not found") // 84
CVEnvMissedError = errors.New("CV env missed error") // 80
CVPrepareRequestError = errors.New("CV prepare request error") // 81
CVRequestServiceError = errors.New("CV request service error") // 82
CVParseResponseError = errors.New("CV parse response error") // 83
CVResultNotFoundError = errors.New("CV result not found") // 84
LLMEnvMissedError = errors.New("LLM env missed error") // 85
StateUnknowError = errors.New("detect state failed") // 89
StateUnknowError = errors.New("detect state failed") // 85
)
// trackings related: [90, 100)
@@ -121,6 +120,15 @@ var (
RiskControlAccountActivation = errors.New("risk control account activation") // 102
)
// LLM related: [110, 120)
var (
LLMEnvMissedError = errors.New("missed LLM env error") // 110
LLMPrepareRequestError = errors.New("prepare LLM request error") // 111
LLMRequestServiceError = errors.New("request LLM service error") // 112
LLMParsePlanningResponseError = errors.New("parse LLM planning response error") // 113
LLMParseAssertionResponseError = errors.New("parse LLM assertion response error") // 114
)
var errorsMap = map[error]int{
// environment
ConfigureError: 3,
@@ -194,14 +202,21 @@ var errorsMap = map[error]int{
MobileUIPopupError: 78,
LoopActionNotFoundError: 79,
// AI related
CVEnvMissedError: 80,
CVRequestError: 81,
CVServiceConnectionError: 82,
CVResponseError: 83,
CVResultNotFoundError: 84,
LLMEnvMissedError: 85,
StateUnknowError: 89,
// CV related
CVEnvMissedError: 80,
CVPrepareRequestError: 81,
CVRequestServiceError: 82,
CVParseResponseError: 83,
CVResultNotFoundError: 84,
StateUnknowError: 85,
// LLM related
LLMEnvMissedError: 110,
LLMPrepareRequestError: 111,
LLMRequestServiceError: 112,
LLMParsePlanningResponseError: 113,
LLMParseAssertionResponseError: 114,
// trackings related
TrackingGetError: 90,

View File

@@ -4,6 +4,7 @@ import (
"reflect"
"github.com/httprunner/httprunner/v5/internal/builtin"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
)
@@ -42,6 +43,8 @@ type TConfig struct {
Path string `json:"path,omitempty" yaml:"path,omitempty"` // testcase file path
PluginSetting *PluginConfig `json:"plugin,omitempty" yaml:"plugin,omitempty"` // plugin config
IgnorePopup bool `json:"ignore_popup,omitempty" yaml:"ignore_popup,omitempty"`
LLMService ai.LLMServiceType `json:"llm_service,omitempty" yaml:"llm_service,omitempty"`
CVService ai.CVServiceType `json:"cv_service,omitempty" yaml:"cv_service,omitempty"`
}
func (c *TConfig) Get() *TConfig {
@@ -108,6 +111,18 @@ func (c *TConfig) SetWeight(weight int) *TConfig {
return c
}
// SetLLMService sets LLM service for current testcase.
func (c *TConfig) SetLLMService(llmService ai.LLMServiceType) *TConfig {
c.LLMService = llmService
return c
}
// SetCVService sets CV service for current testcase.
func (c *TConfig) SetCVService(cvService ai.CVServiceType) *TConfig {
c.CVService = cvService
return c
}
func (c *TConfig) SetWebSocket(times, interval, timeout, size int64) *TConfig {
c.WebSocketSetting = &WebSocketConfig{
ReconnectionTimes: times,

View File

@@ -119,7 +119,7 @@ func LoadCurlCase(path string) (*hrp.TestCaseDef, error) {
}
func readFileLines(path string) ([]string, error) {
file, err := os.Open(path)
file, err := os.OpenFile(path, os.O_RDONLY, 0o600)
if err != nil {
log.Error().Err(err).Str("path", path).Msg("open file failed")
return nil, err

View File

@@ -1 +1 @@
v5.0.0-beta-2504271150
v5.0.0-beta-2504292008

View File

@@ -536,7 +536,7 @@ func (d *Device) List(remotePath string) (devFileInfos []DeviceFileInfo, err err
}
func (d *Device) PushFile(localPath, remotePath string, modification ...time.Time) (err error) {
localFile, err := os.Open(localPath)
localFile, err := os.OpenFile(localPath, os.O_RDONLY, 0o600)
if err != nil {
return err
}
@@ -645,7 +645,7 @@ func (d *Device) installViaABBExec(apk io.ReadSeeker, args ...string) (raw []byt
}
func (d *Device) InstallAPK(apkPath string, args ...string) (string, error) {
apkFile, err := os.Open(apkPath)
apkFile, err := os.OpenFile(apkPath, os.O_RDONLY, 0o600)
if err != nil {
return "", errors.Wrap(err, fmt.Sprintf("open apk file %s failed", apkPath))
}

View File

@@ -418,6 +418,17 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) {
}
r.parametersIterator = parametersIterator
// ai options
aiOpts := []ai.AIServiceOption{}
if parsedConfig.LLMService != "" {
aiOpts = append(aiOpts, ai.WithLLMService(parsedConfig.LLMService))
}
if parsedConfig.CVService == "" {
// default to vedem
parsedConfig.CVService = ai.CVServiceTypeVEDEM
}
aiOpts = append(aiOpts, ai.WithCVService(parsedConfig.CVService))
// parse android devices config
for _, androidDeviceOptions := range parsedConfig.Android {
err := r.parseDeviceConfig(androidDeviceOptions, parsedConfig.Variables)
@@ -435,7 +446,7 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) {
return nil, errors.Wrap(err, "init android driver failed")
}
driverExt := uixt.NewXTDriver(driver, ai.WithCVService(ai.CVServiceTypeVEDEM))
driverExt := uixt.NewXTDriver(driver, aiOpts...)
r.uixtDrivers[androidDeviceOptions.SerialNumber] = driverExt
}
// parse iOS devices config
@@ -455,7 +466,7 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) {
return nil, errors.Wrap(err, "init ios driver failed")
}
driverExt := uixt.NewXTDriver(driver, ai.WithCVService(ai.CVServiceTypeVEDEM))
driverExt := uixt.NewXTDriver(driver, aiOpts...)
r.uixtDrivers[iosDeviceOptions.UDID] = driverExt
}
// parse harmony devices config
@@ -475,7 +486,7 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) {
return nil, errors.Wrap(err, "init harmony driver failed")
}
driverExt := uixt.NewXTDriver(driver, ai.WithCVService(ai.CVServiceTypeVEDEM))
driverExt := uixt.NewXTDriver(driver, aiOpts...)
r.uixtDrivers[harmonyDeviceOptions.ConnectKey] = driverExt
}

View File

@@ -95,7 +95,7 @@ func (s *Summary) GenHTMLReport() error {
}
reportPath := filepath.Join(reportsDir, "report.html")
file, err := os.OpenFile(reportPath, os.O_WRONLY|os.O_CREATE, 0o666)
file, err := os.OpenFile(reportPath, os.O_WRONLY|os.O_CREATE, 0o600)
if err != nil {
log.Error().Err(err).Msg("open file failed")
return err

View File

@@ -49,26 +49,60 @@ type LLMServiceType string
const (
LLMServiceTypeUITARS LLMServiceType = "ui-tars"
LLMServiceTypeGPT4o LLMServiceType = "gpt-4o"
LLMServiceTypeGPT4Vision LLMServiceType = "gpt-4-vision"
LLMServiceTypeQwenVL LLMServiceType = "qwen-vl"
LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3"
)
func WithLLMService(service LLMServiceType) AIServiceOption {
// ILLMService 定义了 LLM 服务接口,包括规划和断言功能
type ILLMService interface {
Call(opts *PlanningOptions) (*PlanningResult, error)
Assert(opts *AssertOptions) (*AssertionResponse, error)
}
func WithLLMService(modelType LLMServiceType) AIServiceOption {
return func(opts *AIServices) {
if service == LLMServiceTypeGPT4o {
var err error
opts.ILLMService, err = NewPlanner(context.Background())
if err != nil {
log.Error().Err(err).Msg("init gpt-4o llm service failed")
os.Exit(code.GetErrorCode(err))
}
// init planner
var planner IPlanner
var err error
switch modelType {
case LLMServiceTypeGPT4o:
// TODO: implement gpt-4o planner and asserter
planner, err = NewPlanner(context.Background())
case LLMServiceTypeUITARS:
planner, err = NewUITarsPlanner(context.Background())
}
if service == LLMServiceTypeUITARS {
var err error
opts.ILLMService, err = NewUITarsPlanner(context.Background())
if err != nil {
log.Error().Err(err).Msg("init ui-tars llm service failed")
os.Exit(code.GetErrorCode(err))
}
if err != nil {
log.Error().Err(err).Msgf("init %s planner failed", modelType)
os.Exit(code.GetErrorCode(err))
}
// init asserter
asserter, err := NewAsserter(context.Background(), modelType)
if err != nil {
log.Error().Err(err).Msgf("init %s asserter failed", modelType)
os.Exit(code.GetErrorCode(err))
}
opts.ILLMService = &combinedLLMService{
planner: planner,
asserter: asserter,
}
}
}
// combinedLLMService 实现了 ILLMService 接口,组合了规划和断言功能
type combinedLLMService struct {
planner IPlanner // 提供规划功能
asserter IAsserter // 提供断言功能
}
// Call 执行规划功能
func (c *combinedLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
return c.planner.Call(opts)
}
// Assert 执行断言功能
func (c *combinedLLMService) Assert(opts *AssertOptions) (*AssertionResponse, error) {
return c.asserter.Assert(opts)
}

62
uixt/ai/ai_ark.go Normal file
View File

@@ -0,0 +1,62 @@
package ai
import (
"os"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
const (
EnvArkBaseURL = "ARK_BASE_URL"
EnvArkAPIKey = "ARK_API_KEY"
EnvArkModelID = "ARK_MODEL_ID"
)
func GetArkModelConfig() (*ark.ChatModelConfig, error) {
if err := config.LoadEnv(); err != nil {
return nil, errors.Wrap(code.LoadEnvError, err.Error())
}
arkBaseURL := os.Getenv(EnvArkBaseURL)
arkAPIKey := os.Getenv(EnvArkAPIKey)
if arkAPIKey == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvArkAPIKey)
}
modelName := os.Getenv(EnvArkModelID)
if modelName == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvArkModelID)
}
timeout := defaultTimeout
// https://www.volcengine.com/docs/82379/1494384?redirect=1
temperature := float32(0.01) // [0, 2] 采样温度。控制了生成文本时对每个候选词的概率分布进行平滑的程度。
// topP := float32(0.7) // [0, 1] 核采样概率阈值。模型会考虑概率质量在 top_p 内的 token 结果。
// maxTokens := int(4096) // 模型可以生成的最大 token 数量。输入 token 和输出 token 的总长度还受模型的上下文长度限制。
// frequencyPenalty := float32(0) // [-2, 2] 频率惩罚系数。如果值为正,会根据新 token 在文本中的出现频率对其进行惩罚,从而降低模型逐字重复的可能性。
modelConfig := &ark.ChatModelConfig{
BaseURL: arkBaseURL,
APIKey: arkAPIKey,
Model: modelName,
Timeout: &timeout,
Temperature: &temperature,
// TopP: &topP,
// MaxTokens: &maxTokens,
// FrequencyPenalty: &frequencyPenalty,
}
// log config info
log.Info().Str("model", modelConfig.Model).
Str("baseURL", modelConfig.BaseURL).
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
}

79
uixt/ai/ai_openai.go Normal file
View File

@@ -0,0 +1,79 @@
package ai
import (
"os"
"github.com/cloudwego/eino-ext/components/model/openai"
openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
"github.com/getkin/kin-openapi/openapi3gen"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
const (
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
EnvOpenAIAPIKey = "OPENAI_API_KEY"
EnvModelName = "LLM_MODEL_NAME"
)
// GetOpenAIModelConfig get OpenAI config
func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
if err := config.LoadEnv(); err != nil {
return nil, errors.Wrap(code.LoadEnvError, err.Error())
}
openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
if openaiBaseURL == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvOpenAIBaseURL)
}
openaiAPIKey := os.Getenv(EnvOpenAIAPIKey)
if openaiAPIKey == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvOpenAIAPIKey)
}
modelName := os.Getenv(EnvModelName)
if modelName == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvModelName)
}
type OutputFormat struct {
Thought string `json:"thought"`
Action string `json:"action"`
Error string `json:"error,omitempty"`
}
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
if err != nil {
return nil, err
}
modelConfig := &openai.ChatModelConfig{
BaseURL: openaiBaseURL,
APIKey: openaiAPIKey,
Model: modelName,
Timeout: defaultTimeout,
// set structured response format
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
ResponseFormat: &openai2.ChatCompletionResponseFormat{
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
Name: "thought_and_action",
Description: "data that describes planning thought and action",
Schema: outputFormatSchema.Value,
Strict: false,
},
},
}
// log config info
log.Info().Str("model", modelConfig.Model).
Str("baseURL", modelConfig.BaseURL).
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
}

218
uixt/ai/asserter.go Normal file
View File

@@ -0,0 +1,218 @@
package ai
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
// IAsserter interface defines the contract for assertion operations
type IAsserter interface {
Assert(opts *AssertOptions) (*AssertionResponse, error)
}
// AssertOptions represents the input options for assertion
type AssertOptions struct {
Assertion string `json:"assertion"` // The assertion text to verify
Screenshot string `json:"screenshot"` // Base64 encoded screenshot
Size types.Size `json:"size"` // Screen dimensions
}
// AssertionResponse represents the response from an AI assertion
type AssertionResponse struct {
Pass bool `json:"pass"`
Thought string `json:"thought"`
}
// Asserter handles assertion using different AI models
type Asserter struct {
ctx context.Context
model model.ToolCallingChatModel
systemPrompt string
history ConversationHistory
modelType LLMServiceType
}
// NewAsserter creates a new Asserter instance
func NewAsserter(ctx context.Context, modelType LLMServiceType) (*Asserter, error) {
asserter := &Asserter{
ctx: ctx,
modelType: modelType,
systemPrompt: getAssertionSystemPrompt(modelType),
}
switch modelType {
case LLMServiceTypeUITARS:
config, err := GetArkModelConfig()
if err != nil {
return nil, err
}
asserter.model, err = ark.NewChatModel(ctx, config)
if err != nil {
return nil, err
}
case LLMServiceTypeGPT4Vision, LLMServiceTypeGPT4o:
config, err := GetOpenAIModelConfig()
if err != nil {
return nil, err
}
asserter.model, err = openai.NewChatModel(ctx, config)
if err != nil {
return nil, err
}
default:
return nil, errors.New("not supported model type for asserter")
}
return asserter, nil
}
// getAssertionSystemPrompt returns the appropriate system prompt for the given model type
func getAssertionSystemPrompt(modelType LLMServiceType) string {
if modelType == LLMServiceTypeUITARS {
return defaultAssertionPrompt + "\n\n" + uiTarsAssertionResponseFormat
}
return defaultAssertionPrompt + "\n\n" + defaultAssertionResponseJsonFormat
}
// Assert performs the assertion check on the screenshot
func (a *Asserter) Assert(opts *AssertOptions) (*AssertionResponse, error) {
// Validate input parameters
if err := validateAssertionInput(opts); err != nil {
return nil, errors.Wrap(err, "validate assertion parameters failed")
}
// Reset history for each new assertion
a.history = ConversationHistory{
{
Role: schema.System,
Content: a.systemPrompt,
},
}
// Create user message with screenshot and assertion
userMsg := &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: opts.Screenshot,
Detail: schema.ImageURLDetailAuto,
},
},
{
Type: schema.ChatMessagePartTypeText,
Text: fmt.Sprintf(`
Here is the assertion. Please tell whether it is truthy according to the screenshot.
=====================================
%s
=====================================
`, opts.Assertion),
},
},
}
// Append user message to history
a.history.Append(userMsg)
// Call model service, generate response
logRequest(a.history)
startTime := time.Now()
resp, err := a.model.Generate(a.ctx, a.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", string(a.modelType)).Msg("call model service for assertion")
if err != nil {
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
}
logResponse(resp)
// Parse result
result, err := parseAssertionResult(resp.Content)
if err != nil {
return nil, errors.Wrap(code.LLMParseAssertionResponseError, err.Error())
}
// Append assistant message to history
a.history.Append(&schema.Message{
Role: schema.Assistant,
Content: resp.Content,
})
return result, nil
}
// validateAssertionInput validates the input parameters for assertion
func validateAssertionInput(opts *AssertOptions) error {
if opts.Assertion == "" {
return errors.Wrap(code.LLMPrepareRequestError, "assertion text is required")
}
if opts.Screenshot == "" {
return errors.Wrap(code.LLMPrepareRequestError, "screenshot is required")
}
return nil
}
// parseAssertionResult parses the model response into AssertionResponse
func parseAssertionResult(content string) (*AssertionResponse, error) {
// Extract JSON content from response
jsonContent := extractJSON(content)
if jsonContent == "" {
return nil, errors.New("could not extract JSON from response")
}
// Parse JSON response
var result AssertionResponse
if err := json.Unmarshal([]byte(jsonContent), &result); err != nil {
return nil, errors.Wrap(code.LLMParseAssertionResponseError, err.Error())
}
return &result, nil
}
// extractJSON extracts JSON content from a string that might contain markdown or other formatting
func extractJSON(content string) string {
content = strings.TrimSpace(content)
// If the content is already a valid JSON, return it
if strings.HasPrefix(content, "{") && strings.HasSuffix(content, "}") {
return content
}
// Try to extract JSON from markdown code blocks
jsonRegex := regexp.MustCompile(`(?:json)?\s*({[\s\S]*?})\s*`)
matches := jsonRegex.FindStringSubmatch(content)
if len(matches) > 1 {
return strings.TrimSpace(matches[1])
}
// Try a more robust approach for JSON with Chinese characters
startIdx := strings.Index(content, "{")
if startIdx >= 0 {
depth := 1
for i := startIdx + 1; i < len(content); i++ {
if content[i] == '{' {
depth++
} else if content[i] == '}' {
depth--
if depth == 0 {
return content[startIdx : i+1]
}
}
}
}
return content
}

View File

@@ -0,0 +1,25 @@
package ai
// Default assertion system prompt
const defaultAssertionPrompt = `You are a senior testing engineer. User will give an assertion and a screenshot of a page. By carefully viewing the screenshot, please tell whether the assertion is truthy.`
// Default assertion response format
const defaultAssertionResponseJsonFormat = `Return in the following JSON format:
{
pass: boolean, // whether the assertion is truthy
thought: string | null, // string, if the result is falsy, give the reason why it is falsy. Otherwise, put null.
}`
// UI-TARS assertion response format
const uiTarsAssertionResponseFormat = `## Output Json String Format
` + "```" + `
"{
"pass": <<is a boolean value from the enum [true, false], true means the assertion is truthy>>,
"thought": "<<is a string, give the reason why the assertion is falsy or truthy. Otherwise.>>"
}"
` + "```" + `
## Rules **MUST** follow
- Make sure to return **only** the JSON, with **no additional** text or explanations.
- Use Chinese in ` + "`Thought`" + ` part.
- You **MUST** strictly follow up the **Output Json String Format**.`

103
uixt/ai/asserter_test.go Normal file
View File

@@ -0,0 +1,103 @@
package ai
import (
"testing"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createAIService(t *testing.T) *AIServices {
aiService := NewAIService(WithLLMService(LLMServiceTypeUITARS))
require.NotNil(t, aiService)
require.NotNil(t, aiService.ILLMService)
return aiService
}
// 测试有效断言
func TestValidAssertions(t *testing.T) {
aiService := createAIService(t)
testCases := []struct {
name string
assertion string
imagePath string
expectPass bool
}{
{
name: "深度思考功能已开启",
assertion: "输入框下方的「深度思考」文字是蓝色的",
imagePath: "testdata/deepseek_think_on.png",
expectPass: true,
},
{
name: "深度思考功能未开启",
assertion: "输入框下方的「深度思考」文字是灰色的",
imagePath: "testdata/deepseek_think_off.png",
expectPass: true,
},
{
name: "联网搜索功能已开启",
assertion: "输入框下方的「联网搜索」文字是蓝色的",
imagePath: "testdata/deepseek_network_on.png",
expectPass: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
imageBase64, size, err := loadImage(tc.imagePath)
require.NoError(t, err)
result, err := aiService.ILLMService.Assert(&AssertOptions{
Assertion: tc.assertion,
Screenshot: imageBase64,
Size: size,
})
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, tc.expectPass, result.Pass)
assert.NotEmpty(t, result.Thought)
})
}
}
// 测试无效参数
func TestInvalidParameters(t *testing.T) {
aiService := createAIService(t)
testCases := []struct {
name string
assertion string
screenshot string
size types.Size
expectedError string
}{
{
name: "缺少截图",
assertion: "测试断言",
screenshot: "",
size: types.Size{},
expectedError: "screenshot is required",
},
{
name: "缺少断言",
assertion: "",
screenshot: "some-base64-data",
size: types.Size{},
expectedError: "assertion text is required",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := aiService.ILLMService.Assert(&AssertOptions{
Assertion: tc.assertion,
Screenshot: tc.screenshot,
Size: tc.size,
})
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
})
}
}

View File

@@ -52,7 +52,7 @@ func (s *vedemCVService) ReadFromPath(imagePath string, opts ...option.ActionOpt
imageResult *CVResult, err error) {
imageBuf, err := os.ReadFile(imagePath)
if err != nil {
err = errors.Wrap(code.CVRequestError,
err = errors.Wrap(code.CVPrepareRequestError,
fmt.Sprintf("read image file error: %v", err))
return
}
@@ -116,21 +116,21 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A
formWriter, err := bodyWriter.CreateFormFile("image", "screenshot.png")
if err != nil {
err = errors.Wrap(code.CVRequestError,
err = errors.Wrap(code.CVPrepareRequestError,
fmt.Sprintf("create form file error: %v", err))
return
}
size, err := formWriter.Write(imageBuf.Bytes())
if err != nil {
err = errors.Wrap(code.CVRequestError,
err = errors.Wrap(code.CVPrepareRequestError,
fmt.Sprintf("write form error: %v", err))
return
}
err = bodyWriter.Close()
if err != nil {
err = errors.Wrap(code.CVRequestError,
err = errors.Wrap(code.CVPrepareRequestError,
fmt.Sprintf("close body writer error: %v", err))
return
}
@@ -146,7 +146,7 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A
req, err = http.NewRequest("POST", os.Getenv("VEDEM_IMAGE_URL"), copiedBodyBuf)
if err != nil {
err = errors.Wrap(code.CVRequestError,
err = errors.Wrap(code.CVPrepareRequestError,
fmt.Sprintf("construct request error: %v", err))
return
}
@@ -192,7 +192,7 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A
break
}
if resp == nil {
err = code.CVServiceConnectionError
err = code.CVRequestServiceError
return
}
@@ -200,13 +200,13 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A
results, err := io.ReadAll(resp.Body)
if err != nil {
err = errors.Wrap(code.CVResponseError,
err = errors.Wrap(code.CVParseResponseError,
fmt.Sprintf("read response body error: %v", err))
return
}
if resp.StatusCode != http.StatusOK {
err = errors.Wrap(code.CVResponseError,
err = errors.Wrap(code.CVParseResponseError,
fmt.Sprintf("unexpected response status code: %d, results: %v",
resp.StatusCode, string(results)))
return
@@ -215,13 +215,13 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A
var imageResponse APIResponseImage
err = json.Unmarshal(results, &imageResponse)
if err != nil {
err = errors.Wrap(code.CVResponseError,
err = errors.Wrap(code.CVParseResponseError,
fmt.Sprintf("json unmarshal veDEM image response body error, response=%s", string(results)))
return
}
if imageResponse.Code != 0 {
err = errors.Wrap(code.CVResponseError,
err = errors.Wrap(code.CVParseResponseError,
fmt.Sprintf("unexpected response data code: %d, message: %s",
imageResponse.Code, imageResponse.Message))
return

View File

@@ -13,11 +13,12 @@ import (
"time"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/rs/zerolog/log"
"github.com/pkg/errors"
)
type ILLMService interface {
type IPlanner interface {
Call(opts *PlanningOptions) (*PlanningResult, error)
}
@@ -57,23 +58,16 @@ const (
)
const (
defaultTimeout = 60 * time.Second
defaultTimeout = 30 * time.Second
)
// Error types
var (
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
ErrInvalidImageData = fmt.Errorf("invalid image data")
)
func validateInput(opts *PlanningOptions) error {
func validatePlanningInput(opts *PlanningOptions) error {
if opts.UserInstruction == "" {
return ErrEmptyInstruction
return errors.Wrap(code.LLMPrepareRequestError, "user instruction is empty")
}
if opts.Message == nil {
return ErrNoConversationHistory
if opts.Message == nil || opts.Message.Role == "" {
return errors.Wrap(code.LLMPrepareRequestError, "user message is empty")
}
if opts.Message.Role == schema.User {
@@ -81,7 +75,7 @@ func validateInput(opts *PlanningOptions) error {
if len(opts.Message.MultiContent) > 0 {
for _, content := range opts.Message.MultiContent {
if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil {
return ErrInvalidImageData
return errors.Wrap(code.LLMPrepareRequestError, "invalid image data")
}
}
}
@@ -90,98 +84,6 @@ func validateInput(opts *PlanningOptions) error {
return nil
}
func logRequest(messages []*schema.Message) {
msgs := make([]*schema.Message, 0, len(messages))
for _, message := range messages {
msg := &schema.Message{
Role: message.Role,
}
if message.Content != "" {
msg.Content = message.Content
} else if len(message.MultiContent) > 0 {
for _, mc := range message.MultiContent {
switch mc.Type {
case schema.ChatMessagePartTypeImageURL:
// Create a copy of the ImageURL to avoid modifying the original message
imageURLCopy := *mc.ImageURL
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
imageURLCopy.URL = "<data:image/base64...>"
}
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
Type: mc.Type,
ImageURL: &imageURLCopy,
})
}
}
}
msgs = append(msgs, msg)
}
log.Debug().Interface("messages", msgs).Msg("log request messages")
}
func logResponse(resp *schema.Message) {
logger := log.Info().Str("role", string(resp.Role)).
Str("content", resp.Content)
if resp.ResponseMeta != nil {
logger = logger.Interface("response_meta", resp.ResponseMeta)
}
if resp.Extra != nil {
logger = logger.Interface("extra", resp.Extra)
}
logger.Msg("log response message")
}
// appendConversationHistory adds a message to the conversation history
func appendConversationHistory(history []*schema.Message, msg *schema.Message) {
// for user image message:
// - keep at most 4 user image messages
// - delete the oldest user image message when the limit is reached
if msg.Role == schema.User {
// get all existing user messages
userImgCount := 0
firstUserImgIndex := -1
// calculate the number of user messages and find the index of the first user message
for i, item := range history {
if item.Role == schema.User {
userImgCount++
if firstUserImgIndex == -1 {
firstUserImgIndex = i
}
}
}
// if there are already 4 user messages, delete the first one before adding the new message
if userImgCount >= 4 && firstUserImgIndex >= 0 {
// delete the first user message
history = append(
history[:firstUserImgIndex],
history[firstUserImgIndex+1:]...,
)
}
// add the new user message to the history
history = append(history, msg)
}
// for assistant message:
// - keep at most the last 10 assistant messages
if msg.Role == schema.Assistant {
// add the new assistant message to the history
history = append(history, msg)
// if there are more than 10 assistant messages, remove the oldest ones
assistantMsgCount := 0
for i := len(history) - 1; i >= 0; i-- {
if history[i].Role == schema.Assistant {
assistantMsgCount++
if assistantMsgCount > 10 {
history = append(history[:i], history[i+1:]...)
}
}
}
}
}
// SavePositionImg saves an image with position markers
func SavePositionImg(params struct {
InputImgBase64 string
@@ -249,37 +151,6 @@ func SavePositionImg(params struct {
return nil
}
// loadImage loads image and returns base64 encoded string
func loadImage(imagePath string) (base64Str string, size types.Size, err error) {
// Read the image file
imageFile, err := os.Open(imagePath)
if err != nil {
return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err)
}
defer imageFile.Close()
// Decode the image to get its resolution
imageData, format, err := image.Decode(imageFile)
if err != nil {
return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err)
}
// Get the resolution of the image
width := imageData.Bounds().Dx()
height := imageData.Bounds().Dy()
size = types.Size{Width: width, Height: height}
// Convert image to base64
buf := new(bytes.Buffer)
if err := png.Encode(buf, imageData); err != nil {
return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
}
base64Str = fmt.Sprintf("data:image/%s;base64,%s", format,
base64.StdEncoding.EncodeToString(buf.Bytes()))
return base64Str, size, nil
}
// maskAPIKey masks the API key
func maskAPIKey(key string) string {
if len(key) <= 8 {

View File

@@ -4,90 +4,20 @@ import (
"context"
"fmt"
_ "image/jpeg"
"os"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/getkin/kin-openapi/openapi3gen"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/types"
)
const (
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
EnvOpenAIAPIKey = "OPENAI_API_KEY"
EnvModelName = "LLM_MODEL_NAME"
)
// GetOpenAIModelConfig get OpenAI config
func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
if err := config.LoadEnv(); err != nil {
return nil, errors.Wrap(code.LoadEnvError, err.Error())
}
openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
if openaiBaseURL == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvOpenAIBaseURL)
}
openaiAPIKey := os.Getenv(EnvOpenAIAPIKey)
if openaiAPIKey == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvOpenAIAPIKey)
}
modelName := os.Getenv(EnvModelName)
if modelName == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvModelName)
}
type OutputFormat struct {
Thought string `json:"thought"`
Action string `json:"action"`
Error string `json:"error,omitempty"`
}
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
if err != nil {
return nil, err
}
modelConfig := &openai.ChatModelConfig{
BaseURL: openaiBaseURL,
APIKey: openaiAPIKey,
Model: modelName,
Timeout: defaultTimeout,
// set structured response format
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
ResponseFormat: &openai2.ChatCompletionResponseFormat{
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
Name: "thought_and_action",
Description: "data that describes planning thought and action",
Schema: outputFormatSchema.Value,
Strict: false,
},
},
}
// log config info
log.Info().Str("model", modelConfig.Model).
Str("baseURL", modelConfig.BaseURL).
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
}
func NewPlanner(ctx context.Context) (*Planner, error) {
config, err := GetOpenAIModelConfig()
if err != nil {
@@ -99,8 +29,8 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
}
return &Planner{
ctx: ctx,
config: config,
model: model,
modelType: LLMServiceTypeGPT4o,
systemPrompt: uiTarsPlanningPrompt, // TODO: change prompt with function calling
}, nil
}
@@ -108,23 +38,23 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
type Planner struct {
ctx context.Context
model model.ToolCallingChatModel
config *openai.ChatModelConfig
systemPrompt string
history []*schema.Message // conversation history
modelType LLMServiceType
history ConversationHistory
}
// Call performs UI planning using Vision Language Model
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
// validate input parameters
if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed")
if err := validatePlanningInput(opts); err != nil {
return nil, errors.Wrap(err, "validate planning parameters failed")
}
// prepare prompt
if len(p.history) == 0 {
// add system message
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
p.history = []*schema.Message{
p.history = ConversationHistory{
{
Role: schema.System,
Content: systemPrompt,
@@ -132,27 +62,27 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
}
}
// append user image message
appendConversationHistory(p.history, opts.Message)
p.history.Append(opts.Message)
// call model service, generate response
logRequest(p.history)
startTime := time.Now()
resp, err := p.model.Generate(p.ctx, p.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", p.config.Model).Msg("call model service")
Str("model", string(p.modelType)).Msg("call model service")
if err != nil {
return nil, fmt.Errorf("request model service failed: %w", err)
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
}
logResponse(resp)
// parse result
result, err := p.parseResult(resp, opts.Size)
if err != nil {
return nil, errors.Wrap(err, "parse result failed")
return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error())
}
// append assistant message
appendConversationHistory(p.history, &schema.Message{
p.history.Append(&schema.Message{
Role: schema.Assistant,
Content: result.ActionSummary,
})

View File

@@ -0,0 +1,29 @@
package ai
// https://www.volcengine.com/docs/82379/1536429
const uiTarsPlanningPrompt = `
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
` + "```" + `
Thought: ...
Action: ...
` + "```" + `
## Action Space
click(start_box='[x1, y1, x2, y2]')
left_double(start_box='[x1, y1, x2, y2]')
right_single(start_box='[x1, y1, x2, y2]')
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
hotkey(key='')
type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `.
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use Chinese in ` + "`Thought`" + ` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part.
## User Instruction
`

View File

@@ -1,11 +1,18 @@
package ai
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"image"
"image/jpeg"
"image/png"
"os"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -195,6 +202,53 @@ func TestChatList(t *testing.T) {
require.NotNil(t, result)
}
func TestHandleSwitch(t *testing.T) {
userInstruction := "发送框下方的联网搜索开关是开启状态" // 点击开启联网搜索开关
// 检查发送框下方的联网搜索开关,蓝色为开启状态,灰色为关闭状态;若开关处于关闭状态,则点击进行开启
planner, err := NewUITarsPlanner(context.Background())
require.NoError(t, err)
testCases := []struct {
imageFile string
actionType ActionType
}{
{"testdata/deepseek_think_off.png", ActionTypeClick},
{"testdata/deepseek_think_on.png", ActionTypeFinished},
{"testdata/deepseek_network_on.png", ActionTypeFinished},
}
for _, tc := range testCases {
imageBase64, size, err := loadImage(tc.imageFile)
require.NoError(t, err)
opts := &PlanningOptions{
UserInstruction: userInstruction,
Message: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
},
Size: size,
}
// Execute planning
result, err := planner.Call(opts)
// Validate results
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, result.NextActions[0].ActionType, tc.actionType,
"Unexpected action type for image file: %s", tc.imageFile)
}
}
func TestValidateInput(t *testing.T) {
imageBase64, size, err := loadImage("testdata/popup_risk_warning.png")
require.NoError(t, err)
@@ -212,7 +266,7 @@ func TestValidateInput(t *testing.T) {
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
@@ -228,41 +282,46 @@ func TestValidateInput(t *testing.T) {
opts: &PlanningOptions{
UserInstruction: "",
Message: &schema.Message{
Role: schema.User,
Content: "",
Role: schema.User,
MultiContent: []schema.ChatMessagePart{},
},
Size: size,
},
wantErr: ErrEmptyInstruction,
wantErr: code.LLMPrepareRequestError,
},
{
name: "empty conversation history",
opts: &PlanningOptions{
UserInstruction: "点击立即卸载按钮",
Message: &schema.Message{},
Size: size,
},
wantErr: ErrNoConversationHistory,
wantErr: code.LLMPrepareRequestError,
},
{
name: "invalid image data",
opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮",
Message: &schema.Message{
Role: schema.User,
Content: "no image",
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeImageURL,
Text: "no image",
},
},
},
Size: size,
},
wantErr: ErrInvalidImageData,
wantErr: code.LLMPrepareRequestError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateInput(tt.opts)
err := validatePlanningInput(tt.opts)
if tt.wantErr != nil {
assert.Error(t, err)
assert.Equal(t, tt.wantErr, err)
} else {
assert.NoError(t, err)
}
@@ -361,3 +420,42 @@ func TestLoadImage(t *testing.T) {
assert.Greater(t, jpegSize.Width, 0)
assert.Greater(t, jpegSize.Height, 0)
}
// loadImage loads image and returns base64 encoded string
func loadImage(imagePath string) (base64Str string, size types.Size, err error) {
// Read the image file
imageFile, err := os.OpenFile(imagePath, os.O_RDONLY, 0o600)
if err != nil {
return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err)
}
defer imageFile.Close()
// Decode the image to get its resolution
imageData, format, err := image.Decode(imageFile)
if err != nil {
return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err)
}
// Get the resolution of the image
width := imageData.Bounds().Dx()
height := imageData.Bounds().Dy()
size = types.Size{Width: width, Height: height}
// Convert image to base64
buf := new(bytes.Buffer)
// 根据图像格式选择正确的编码器
if format == "jpeg" || format == "jpg" {
if err := jpeg.Encode(buf, imageData, nil); err != nil {
return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
}
} else {
// 默认使用 PNG 编码
if err := png.Encode(buf, imageData); err != nil {
return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
}
}
base64Str = fmt.Sprintf("data:image/%s;base64,%s", format,
base64.StdEncoding.EncodeToString(buf.Bytes()))
return base64Str, size, nil
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"math"
"os"
"regexp"
"strconv"
"strings"
@@ -14,55 +13,12 @@ import (
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
const (
EnvArkBaseURL = "ARK_BASE_URL"
EnvArkAPIKey = "ARK_API_KEY"
EnvArkModelID = "ARK_MODEL_ID"
)
func GetArkModelConfig() (*ark.ChatModelConfig, error) {
if err := config.LoadEnv(); err != nil {
return nil, errors.Wrap(code.LoadEnvError, err.Error())
}
arkBaseURL := os.Getenv(EnvArkBaseURL)
arkAPIKey := os.Getenv(EnvArkAPIKey)
if arkAPIKey == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvArkAPIKey)
}
modelName := os.Getenv(EnvArkModelID)
if modelName == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvArkModelID)
}
timeout := defaultTimeout
temp := float32(0.7)
modelConfig := &ark.ChatModelConfig{
BaseURL: arkBaseURL,
APIKey: arkAPIKey,
Model: modelName,
Temperature: &temp,
Timeout: &timeout,
}
// log config info
log.Info().Str("model", modelConfig.Model).
Str("baseURL", modelConfig.BaseURL).
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
}
func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) {
config, err := GetArkModelConfig()
if err != nil {
@@ -75,60 +31,32 @@ func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) {
return &UITarsPlanner{
ctx: ctx,
config: config,
model: chatModel,
modelType: LLMServiceTypeUITARS,
systemPrompt: uiTarsPlanningPrompt,
}, nil
}
// https://www.volcengine.com/docs/82379/1536429
const uiTarsPlanningPrompt = `
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
` + "```" + `
Thought: ...
Action: ...
` + "```" + `
## Action Space
click(start_box='[x1, y1, x2, y2]')
left_double(start_box='[x1, y1, x2, y2]')
right_single(start_box='[x1, y1, x2, y2]')
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
hotkey(key='')
type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `.
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use Chinese in ` + "`Thought`" + ` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part.
## User Instruction
`
type UITarsPlanner struct {
ctx context.Context
model model.ToolCallingChatModel
config *ark.ChatModelConfig
systemPrompt string
history []*schema.Message // conversation history
modelType LLMServiceType
history ConversationHistory
}
// Call performs UI planning using Vision Language Model
func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) {
// validate input parameters
if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed")
if err := validatePlanningInput(opts); err != nil {
return nil, errors.Wrap(err, "validate planning parameters failed")
}
// prepare prompt
if len(p.history) == 0 {
// add system message
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
p.history = []*schema.Message{
p.history = ConversationHistory{
{
Role: schema.System,
Content: systemPrompt,
@@ -136,27 +64,27 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) {
}
}
// append user image message
appendConversationHistory(p.history, opts.Message)
p.history.Append(opts.Message)
// call model service, generate response
logRequest(p.history)
startTime := time.Now()
resp, err := p.model.Generate(p.ctx, p.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", p.config.Model).Msg("call model service")
Str("model", string(p.modelType)).Msg("call model service")
if err != nil {
return nil, fmt.Errorf("request model service failed: %w", err)
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
}
logResponse(resp)
// parse result
result, err := p.parseResult(resp, opts.Size)
if err != nil {
return nil, errors.Wrap(err, "parse result failed")
return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error())
}
// append assistant message
appendConversationHistory(p.history, &schema.Message{
p.history.Append(&schema.Message{
Role: schema.Assistant,
Content: result.ActionSummary,
})

103
uixt/ai/session.go Normal file
View File

@@ -0,0 +1,103 @@
package ai
import (
"strings"
"github.com/cloudwego/eino/schema"
"github.com/rs/zerolog/log"
)
// ConversationHistory represents a sequence of chat messages
type ConversationHistory []*schema.Message
// Append adds a new message to the conversation history
func (h *ConversationHistory) Append(msg *schema.Message) {
// for user image message:
// - keep at most 4 user image messages
// - delete the oldest user image message when the limit is reached
if msg.Role == schema.User {
// get all existing user messages
userImgCount := 0
firstUserImgIndex := -1
// calculate the number of user messages and find the index of the first user message
for i, item := range *h {
if item.Role == schema.User {
userImgCount++
if firstUserImgIndex == -1 {
firstUserImgIndex = i
}
}
}
// if there are already 4 user messages, delete the first one before adding the new message
if userImgCount >= 4 && firstUserImgIndex >= 0 {
// delete the first user message
*h = append(
(*h)[:firstUserImgIndex],
(*h)[firstUserImgIndex+1:]...,
)
}
// add the new user message to the history
*h = append(*h, msg)
}
// for assistant message:
// - keep at most the last 10 assistant messages
if msg.Role == schema.Assistant {
// add the new assistant message to the history
*h = append(*h, msg)
// if there are more than 10 assistant messages, remove the oldest ones
assistantMsgCount := 0
for i := len(*h) - 1; i >= 0; i-- {
if (*h)[i].Role == schema.Assistant {
assistantMsgCount++
if assistantMsgCount > 10 {
*h = append((*h)[:i], (*h)[i+1:]...)
}
}
}
}
}
func logRequest(messages ConversationHistory) {
msgs := make(ConversationHistory, 0, len(messages))
for _, message := range messages {
msg := &schema.Message{
Role: message.Role,
}
if message.Content != "" {
msg.Content = message.Content
} else if len(message.MultiContent) > 0 {
for _, mc := range message.MultiContent {
switch mc.Type {
case schema.ChatMessagePartTypeImageURL:
// Create a copy of the ImageURL to avoid modifying the original message
imageURLCopy := *mc.ImageURL
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
imageURLCopy.URL = "<data:image/base64...>"
}
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
Type: mc.Type,
ImageURL: &imageURLCopy,
})
}
}
}
msgs = append(msgs, msg)
}
log.Debug().Interface("messages", msgs).Msg("log request messages")
}
func logResponse(resp *schema.Message) {
logger := log.Info().Str("role", string(resp.Role)).
Str("content", resp.Content)
if resp.ResponseMeta != nil {
logger = logger.Interface("response_meta", resp.ResponseMeta)
}
if resp.Extra != nil {
logger = logger.Interface("extra", resp.Extra)
}
logger.Msg("log response message")
}

BIN
uixt/ai/testdata/deepseek_network_on.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 377 KiB

BIN
uixt/ai/testdata/deepseek_think_off.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 371 KiB

BIN
uixt/ai/testdata/deepseek_think_on.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 369 KiB

BIN
uixt/ai/testdata/llk_4.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 585 KiB

View File

@@ -701,7 +701,7 @@ func (ad *ADBDriver) StopCaptureLog() (result interface{}, err error) {
return pointRes, nil
}
reader, err := os.Open(files[0])
reader, err := os.OpenFile(files[0], os.O_RDONLY, 0o600)
if err != nil {
log.Info().Msg("open File error")
return pointRes, nil

View File

@@ -55,14 +55,6 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
return nil
}
func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (string, error) {
return "", nil
}
func (dExt *XTDriver) AIAssert(text string, opts ...option.ActionOption) error {
return nil
}
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
if dExt.LLMService == nil {
return nil, errors.New("LLM service is not initialized")
@@ -116,3 +108,45 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
}
return result, nil
}
func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (string, error) {
return "", nil
}
func (dExt *XTDriver) AIAssert(assertion string, opts ...option.ActionOption) error {
if dExt.LLMService == nil {
return errors.New("LLM service is not initialized")
}
compressedBufSource, err := dExt.GetScreenShotBuffer()
if err != nil {
return err
}
// convert buffer to base64 string
screenShotBase64 := "data:image/jpeg;base64," +
base64.StdEncoding.EncodeToString(compressedBufSource.Bytes())
// get window size
size, err := dExt.IDriver.WindowSize()
if err != nil {
return errors.Wrap(err, "get window size for AI assertion failed")
}
// execute assertion
assertOpts := &ai.AssertOptions{
Assertion: assertion,
Screenshot: screenShotBase64,
Size: size,
}
result, err := dExt.LLMService.Assert(assertOpts)
if err != nil {
return errors.Wrap(err, "AI assertion failed")
}
if !result.Pass {
return errors.New(result.Thought)
}
return nil
}

View File

@@ -127,6 +127,9 @@ func TestDriverExt_TapByLLM(t *testing.T) {
driver := setupDriverExt(t)
err := driver.AIAction("点击第一个帖子的作者头像")
assert.Nil(t, err)
err = driver.AIAssert("当前在个人介绍页")
assert.Nil(t, err)
}
func TestDriverExt_StartToGoal(t *testing.T) {

View File

@@ -185,8 +185,8 @@ func (dExt *XTDriver) DoValidation(check, assert, expected string, message ...st
switch check {
case SelectorOCR:
err = dExt.assertOCR(expected, assert)
// case SelectorAI:
// // TODO
case SelectorAI:
err = dExt.AIAssert(assert)
case SelectorForegroundApp:
err = dExt.assertForegroundApp(expected, assert)
default:

View File

@@ -970,7 +970,7 @@ func (wd *WDADriver) StartCaptureLog(identifier ...string) error {
func (wd *WDADriver) PushImage(localPath string) error {
log.Info().Str("localPath", localPath).Msg("WDADriver.PushImage")
localFile, err := os.Open(localPath)
localFile, err := os.OpenFile(localPath, os.O_RDONLY, 0o600)
if err != nil {
return err
}