feat: add TapByLLM/PlanNextAction for XTDriver

This commit is contained in:
lilong.129
2025-03-19 21:05:19 +08:00
parent ec93382e47
commit 55acaceb09
26 changed files with 290 additions and 337 deletions

View File

@@ -1,6 +1,7 @@
package ai
import (
"context"
"os"
"github.com/rs/zerolog/log"
@@ -46,6 +47,7 @@ func WithCVService(service CVServiceType) AIServiceOption {
type LLMServiceType string
const (
LLMServiceTypeUITARS LLMServiceType = "ui-tars"
LLMServiceTypeGPT4o LLMServiceType = "gpt-4o"
LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3"
)
@@ -60,5 +62,12 @@ func WithLLMService(service LLMServiceType) AIServiceOption {
os.Exit(code.GetErrorCode(err))
}
}
if service == LLMServiceTypeUITARS {
var err error
opts.ILLMService, err = NewPlanner(context.Background())
if err != nil {
log.Error().Err(err).Msg("init ui-tars llm service failed")
}
}
}
}

View File

@@ -5,7 +5,7 @@ import "testing"
func TestOption(t *testing.T) {
options := NewAIService(
WithCVService(CVServiceTypeOpenCV),
WithLLMService(LLMServiceTypeDeepSeekV3),
WithLLMService(LLMServiceTypeUITARS),
)
t.Log(options)
}

View File

@@ -30,7 +30,7 @@ type APIResponseImage struct {
}
func NewVEDEMImageService() (*vedemCVService, error) {
if err := checkEnv(); err != nil {
if err := checkEnvCV(); err != nil {
return nil, err
}
return &vedemCVService{}, nil
@@ -230,7 +230,7 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A
return imageResult, nil
}
func checkEnv() error {
func checkEnvCV() error {
vedemImageURL := os.Getenv("VEDEM_IMAGE_URL")
if vedemImageURL == "" {
return errors.Wrap(code.CVEnvMissedError, "VEDEM_IMAGE_URL missed")

170
uixt/ai/env.go Normal file
View File

@@ -0,0 +1,170 @@
package ai
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strconv"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/joho/godotenv"
"github.com/rs/zerolog/log"
)
const (
defaultTimeout = 60 * time.Second
)
type OpenAIInitConfig struct {
ReportURL string `json:"REPORT_SERVER_URL"`
Headers map[string]string `json:"defaultHeaders"`
}
const (
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
EnvOpenAIAPIKey = "OPENAI_API_KEY"
EnvModelName = "MIDSCENE_MODEL_NAME"
EnvOpenAIInitConfigJSON = "MIDSCENE_OPENAI_INIT_CONFIG_JSON"
EnvUseVLMUITars = "MIDSCENE_USE_VLM_UI_TARS"
)
// loadEnv loads environment variables from a file
func loadEnv(envPath string) error {
err := godotenv.Load(envPath)
if err != nil {
return err
}
log.Info().Str("path", envPath).Msg("load env success")
return nil
}
func GetEnvConfig(key string) string {
return os.Getenv(key)
}
func GetEnvConfigInJSON(key string) (map[string]interface{}, error) {
value := GetEnvConfig(key)
if value == "" {
return nil, nil
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(value), &result); err != nil {
return nil, err
}
return result, nil
}
func GetEnvConfigInBool(key string) bool {
value := GetEnvConfig(key)
if value == "" {
return false
}
boolValue, _ := strconv.ParseBool(value)
return boolValue
}
// GetEnvConfigOrDefault get env config or default value
func GetEnvConfigOrDefault(key, defaultValue string) string {
value := GetEnvConfig(key)
if value == "" {
return defaultValue
}
return value
}
func GetEnvConfigInInt(key string, defaultValue int) int {
value := GetEnvConfig(key)
if value == "" {
return defaultValue
}
intValue, err := strconv.Atoi(value)
if err != nil {
return defaultValue
}
return intValue
}
// CustomTransport is a custom RoundTripper that adds headers to every request
type CustomTransport struct {
Transport http.RoundTripper
Headers map[string]string
}
// RoundTrip executes a single HTTP transaction and adds custom headers
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, value := range c.Headers {
req.Header.Set(key, value)
}
return c.Transport.RoundTrip(req)
}
// GetModelConfig get OpenAI config
func GetModelConfig() (*openai.ChatModelConfig, error) {
envConfig := &OpenAIInitConfig{
Headers: make(map[string]string),
}
// read from JSON config first
jsonStr := GetEnvConfig(EnvOpenAIInitConfigJSON)
if jsonStr != "" {
if err := json.Unmarshal([]byte(jsonStr), envConfig); err != nil {
return nil, err
}
}
config := &openai.ChatModelConfig{
HTTPClient: &http.Client{
Timeout: defaultTimeout,
Transport: &CustomTransport{
Transport: http.DefaultTransport,
Headers: envConfig.Headers,
},
},
}
if baseURL := GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {
config.BaseURL = baseURL
} else {
return nil, fmt.Errorf("miss env %s", EnvOpenAIBaseURL)
}
if apiKey := GetEnvConfig(EnvOpenAIAPIKey); apiKey != "" {
config.APIKey = apiKey
} else {
return nil, fmt.Errorf("miss env %s", EnvOpenAIAPIKey)
}
if modelName := GetEnvConfig(EnvModelName); modelName != "" {
config.Model = modelName
} else {
return nil, fmt.Errorf("miss env %s", EnvModelName)
}
// log config info
log.Info().Str("model", config.Model).
Str("baseURL", config.BaseURL).
Str("apiKey", maskAPIKey(config.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return config, nil
}
// maskAPIKey masks the API key
func maskAPIKey(key string) string {
if len(key) <= 8 {
return "******"
}
return key[:4] + "******" + key[len(key)-4:]
}
func IsUseVLMUITars() bool {
return GetEnvConfigInBool(EnvUseVLMUITars)
}

View File

@@ -1,9 +1,9 @@
package ai
import "context"
import "github.com/cloudwego/eino/schema"
type ILLMService interface {
Call(ctx context.Context, prompt string) (string, error)
Call(opts *PlanningOptions) (*PlanningResult, error)
}
func NewGPT4oLLMService() (*openaiLLMService, error) {
@@ -12,6 +12,46 @@ func NewGPT4oLLMService() (*openaiLLMService, error) {
type openaiLLMService struct{}
func (s openaiLLMService) Call(ctx context.Context, prompt string) (string, error) {
return "", nil
func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
return nil, nil
}
// PlanningOptions represents the input options for planning
type PlanningOptions struct {
UserInstruction string `json:"user_instruction"`
ConversationHistory []*schema.Message `json:"conversation_history"`
}
// PlanningResult represents the result of planning
type PlanningResult struct {
NextActions []ParsedAction `json:"actions"`
ActionSummary string `json:"summary"`
}
// VLMResponse represents the response from the Vision Language Model
type VLMResponse struct {
Actions []ParsedAction `json:"actions"`
Error string `json:"error,omitempty"`
}
// ParsedAction represents a parsed action from the VLM response
type ParsedAction struct {
ActionType ActionType `json:"actionType"`
ActionInputs map[string]interface{} `json:"actionInputs"`
Thought string `json:"thought"`
}
type ActionType string
const (
ActionTypeClick ActionType = "click"
ActionTypeTap ActionType = "tap"
ActionTypeDrag ActionType = "drag"
ActionTypeSwipe ActionType = "swipe"
ActionTypeWait ActionType = "wait"
ActionTypeFinished ActionType = "finished"
ActionTypeCallUser ActionType = "call_user"
ActionTypeType ActionType = "type"
ActionTypeScroll ActionType = "scroll"
ActionTypeHotkey ActionType = "hotkey"
)

239
uixt/ai/parser.go Normal file
View File

@@ -0,0 +1,239 @@
package ai
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/pkg/errors"
)
// NewActionParser creates a new ActionParser instance
func NewActionParser(factor float64) *ActionParser {
return &ActionParser{
Factor: factor,
}
}
// ActionParser parses VLM responses and converts them to structured actions
type ActionParser struct {
Factor float64 // TODO
}
// Parse parses the prediction text and extracts actions
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
// try parsing JSON format, from VLM like GPT-4o
var jsonActions []ParsedAction
jsonActions, jsonErr := p.parseJSON(predictionText)
if jsonErr == nil {
return jsonActions, nil
}
// json parsing failed, try parsing Thought/Action format, from VLM like UI-TARS
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
if thoughtErr == nil {
return thoughtActions, nil
}
return nil, errors.Wrap(thoughtErr, "parse planner response failed")
}
// parseJSON tries to parse the response as JSON format
func (p *ActionParser) parseJSON(predictionText string) ([]ParsedAction, error) {
predictionText = strings.TrimSpace(predictionText)
if strings.HasPrefix(predictionText, "```json") && strings.HasSuffix(predictionText, "```") {
predictionText = strings.TrimPrefix(predictionText, "```json")
predictionText = strings.TrimSuffix(predictionText, "```")
}
predictionText = strings.TrimSpace(predictionText)
var response VLMResponse
if err := json.Unmarshal([]byte(predictionText), &response); err != nil {
return nil, fmt.Errorf("failed to parse VLM response: %v", err)
}
if response.Error != "" {
return nil, errors.New(response.Error)
}
if len(response.Actions) == 0 {
return nil, errors.New("no actions returned from VLM")
}
// normalize actions
var normalizedActions []ParsedAction
for _, action := range response.Actions {
if err := p.normalizeAction(&action); err != nil {
return nil, errors.Wrap(err, "failed to normalize action")
}
normalizedActions = append(normalizedActions, action)
}
return normalizedActions, nil
}
// parseThoughtAction parses the Thought/Action format response
func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction, error) {
thoughtRegex := regexp.MustCompile(`(?is)Thought:(.+?)Action:`)
actionRegex := regexp.MustCompile(`(?is)Action:(.+)`)
// extract Thought part
thoughtMatch := thoughtRegex.FindStringSubmatch(predictionText)
var thought string
if len(thoughtMatch) > 1 {
thought = strings.TrimSpace(thoughtMatch[1])
}
// extract Action part, e.g. "click(start_box='(552,454)')"
actionMatch := actionRegex.FindStringSubmatch(predictionText)
if len(actionMatch) < 2 {
return nil, errors.New("no action found in the response")
}
actionText := strings.TrimSpace(actionMatch[1])
// parse action type and parameters
return p.parseActionText(actionText, thought)
}
// parseActionText parses the action text to extract the action type and parameters
func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedAction, error) {
// remove trailing comments
if idx := strings.Index(actionText, "#"); idx > 0 {
actionText = strings.TrimSpace(actionText[:idx])
}
// supported action types and regexes
actionRegexes := map[ActionType]*regexp.Regexp{
"click": regexp.MustCompile(`click\(start_box='([^']+)'\)`),
"left_double": regexp.MustCompile(`left_double\(start_box='([^']+)'\)`),
"right_single": regexp.MustCompile(`right_single\(start_box='([^']+)'\)`),
"drag": regexp.MustCompile(`drag\(start_box='([^']+)', end_box='([^']+)'\)`),
"hotkey": regexp.MustCompile(`hotkey\(key='([^']+)'\)`),
"type": regexp.MustCompile(`type\(content='([^']+)'\)`),
"scroll": regexp.MustCompile(`scroll\(start_box='([^']+)', direction='([^']+)'\)`),
"wait": regexp.MustCompile(`wait\(\)`),
"finished": regexp.MustCompile(`finished\(\)`),
"call_user": regexp.MustCompile(`call_user\(\)`),
}
parsedActions := make([]ParsedAction, 0)
for actionType, regex := range actionRegexes {
matches := regex.FindStringSubmatch(actionText)
if len(matches) == 0 {
continue
}
var action ParsedAction
action.ActionType = actionType
action.ActionInputs = make(map[string]interface{})
action.Thought = thought
// parse parameters based on action type
switch actionType {
case ActionTypeClick:
if len(matches) > 1 {
coord, err := p.normalizeCoordinates(matches[1])
if err != nil {
return nil, errors.Wrapf(err, "normalize point failed: %s", matches[1])
}
action.ActionInputs["startBox"] = coord
}
case ActionTypeDrag:
if len(matches) > 2 {
// handle start point
startBox, err := p.normalizeCoordinates(matches[1])
if err != nil {
return nil, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
}
action.ActionInputs["startBox"] = startBox
// handle end point
endBox, err := p.normalizeCoordinates(matches[2])
if err != nil {
return nil, errors.Wrapf(err, "normalize endBox failed: %s", matches[2])
}
action.ActionInputs["endBox"] = endBox
}
case ActionTypeHotkey:
if len(matches) > 1 {
action.ActionInputs["key"] = matches[1]
}
case ActionTypeType:
if len(matches) > 1 {
action.ActionInputs["content"] = matches[1]
}
case ActionTypeScroll:
if len(matches) > 2 {
startBox, err := p.normalizeCoordinates(matches[1])
if err != nil {
return nil, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
}
action.ActionInputs["startBox"] = startBox
action.ActionInputs["direction"] = matches[2]
}
case ActionTypeWait, ActionTypeFinished, ActionTypeCallUser:
// 这些动作没有额外参数
}
parsedActions = append(parsedActions, action)
}
if len(parsedActions) == 0 {
return nil, fmt.Errorf("no valid actions returned from VLM")
}
return parsedActions, nil
}
// normalizeAction normalizes the coordinates in the action
func (p *ActionParser) normalizeAction(action *ParsedAction) error {
switch action.ActionType {
case "click", "drag":
// handle click and drag action coordinates
if startBox, ok := action.ActionInputs["startBox"].(string); ok {
normalized, err := p.normalizeCoordinates(startBox)
if err != nil {
return fmt.Errorf("failed to normalize startBox: %w", err)
}
action.ActionInputs["startBox"] = normalized
}
if endBox, ok := action.ActionInputs["endBox"].(string); ok {
normalized, err := p.normalizeCoordinates(endBox)
if err != nil {
return fmt.Errorf("failed to normalize endBox: %w", err)
}
action.ActionInputs["endBox"] = normalized
}
}
return nil
}
// normalizeCoordinates normalizes the coordinates based on the factor
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
// check empty string
if coordStr == "" {
return nil, fmt.Errorf("empty coordinate string")
}
if !strings.Contains(coordStr, ",") {
return nil, fmt.Errorf("invalid coordinate string: %s", coordStr)
}
// remove possible brackets and split coordinates
coordStr = strings.Trim(coordStr, "[]() \t")
// try parsing JSON array
jsonStr := coordStr
if !strings.HasPrefix(jsonStr, "[") {
jsonStr = "[" + coordStr + "]"
}
err = json.Unmarshal([]byte(jsonStr), &coords)
if err != nil {
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
}
return coords, nil
}

352
uixt/ai/planner.go Normal file
View File

@@ -0,0 +1,352 @@
package ai
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
"image/png"
"os"
"strings"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/schema"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
// Error types
var (
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
ErrInvalidImageData = fmt.Errorf("invalid image data")
)
func NewPlanner(ctx context.Context) (*Planner, error) {
config, err := GetModelConfig()
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
}
model, err := openai.NewChatModel(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
}
parser := NewActionParser(1000)
return &Planner{
ctx: ctx,
model: model,
parser: parser,
}, nil
}
type Planner struct {
ctx context.Context
model *openai.ChatModel
parser *ActionParser
}
// Call performs UI planning using Vision Language Model
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
// validate input parameters
if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed")
}
// call VLM service
resp, err := p.callVLMService(opts)
if err != nil {
return nil, errors.Wrap(err, "call VLM service failed")
}
// parse result
result, err := p.parseResult(resp)
if err != nil {
return nil, errors.Wrap(err, "parse result failed")
}
log.Info().
Interface("summary", result.ActionSummary).
Interface("actions", result.NextActions).
Msg("get VLM planning result")
return result, nil
}
func validateInput(opts *PlanningOptions) error {
if opts.UserInstruction == "" {
return ErrEmptyInstruction
}
if len(opts.ConversationHistory) == 0 {
return ErrNoConversationHistory
}
// ensure at least one image URL
hasImageURL := false
for _, msg := range opts.ConversationHistory {
if msg.Role == "user" {
// check MultiContent
if len(msg.MultiContent) > 0 {
for _, content := range msg.MultiContent {
if content.Type == "image_url" && content.ImageURL != nil {
hasImageURL = true
break
}
}
}
}
if hasImageURL {
break
}
}
if !hasImageURL {
return ErrInvalidImageData
}
return nil
}
// callVLMService makes the actual call to the VLM service
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
log.Info().Msg("calling VLM service...")
// prepare prompt
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
messages := []*schema.Message{
{
Role: schema.System,
Content: systemPrompt,
},
}
messages = append(messages, opts.ConversationHistory...)
// generate response
resp, err := p.model.Generate(p.ctx, messages)
if err != nil {
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
}
log.Info().Str("content", resp.Content).Msg("get VLM response")
return resp, nil
}
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
// parse response
actions, err := p.parser.Parse(msg.Content)
if err != nil {
return nil, fmt.Errorf("failed to parse actions: %w", err)
}
// process response
result, err := processVLMResponse(actions)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
}
return result, nil
}
// processVLMResponse processes the VLM response and converts it to PlanningResult
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
log.Info().Msg("processing VLM response...")
if len(actions) == 0 {
return nil, fmt.Errorf("no actions returned from VLM")
}
// validate and post-process each action
for i := range actions {
// validate action type
switch actions[i].ActionType {
case "click", "left_double", "right_single":
validateCoordinateAction(&actions[i], "startBox")
case "drag":
validateCoordinateAction(&actions[i], "startBox")
validateCoordinateAction(&actions[i], "endBox")
case "scroll":
validateCoordinateAction(&actions[i], "startBox")
validateScrollDirection(&actions[i])
case "type":
validateTypeContent(&actions[i])
case "hotkey":
validateHotkeyAction(&actions[i])
case "wait", "finished", "call_user":
// these actions do not need extra parameters
default:
log.Printf("warning: unknown action type: %s, will try to continue processing", actions[i].ActionType)
}
}
// extract action summary
actionSummary := extractActionSummary(actions)
return &PlanningResult{
NextActions: actions,
ActionSummary: actionSummary,
}, nil
}
// extractActionSummary extracts the summary from the actions
func extractActionSummary(actions []ParsedAction) string {
if len(actions) == 0 {
return ""
}
// use the Thought of the first action as summary
if actions[0].Thought != "" {
return actions[0].Thought
}
// if no Thought, generate summary from action type
action := actions[0]
switch action.ActionType {
case "click":
return "点击操作"
case "drag":
return "拖拽操作"
case "left_double":
return "双击操作"
case "right_single":
return "右键点击操作"
case "scroll":
direction, _ := action.ActionInputs["direction"].(string)
return fmt.Sprintf("滚动操作 (%s)", direction)
case "type":
content, _ := action.ActionInputs["content"].(string)
if len(content) > 20 {
content = content[:20] + "..."
}
return fmt.Sprintf("输入文本: %s", content)
case "hotkey":
key, _ := action.ActionInputs["key"].(string)
return fmt.Sprintf("快捷键: %s", key)
case "wait":
return "等待操作"
case "finished":
return "完成操作"
case "call_user":
return "请求用户协助"
default:
return fmt.Sprintf("执行 %s 操作", action.ActionType)
}
}
// validateCoordinateAction 验证坐标类动作
func validateCoordinateAction(action *ParsedAction, boxField string) {
// TODO
}
// validateScrollDirection 验证滚动方向
func validateScrollDirection(action *ParsedAction) {
if direction, ok := action.ActionInputs["direction"].(string); !ok || direction == "" {
// default to down
action.ActionInputs["direction"] = "down"
} else {
switch strings.ToLower(direction) {
case "up", "down", "left", "right":
// keep original direction
default:
action.ActionInputs["direction"] = "down"
log.Warn().Str("direction", direction).Msg("invalid scroll direction, set to default")
}
}
}
// validateTypeContent 验证输入文本内容
func validateTypeContent(action *ParsedAction) {
if content, ok := action.ActionInputs["content"]; !ok || content == "" {
// default to empty string
action.ActionInputs["content"] = ""
log.Warn().Msg("type action missing content parameter, set to default")
}
}
// validateHotkeyAction 验证快捷键动作
func validateHotkeyAction(action *ParsedAction) {
if key, ok := action.ActionInputs["key"]; !ok || key == "" {
// 为空或缺失的键设置默认值
action.ActionInputs["key"] = "Enter"
log.Printf("警告: hotkey动作缺少key参数, 已设置默认值")
}
}
// SavePositionImg saves an image with position markers
func SavePositionImg(params struct {
InputImgBase64 string
Rect struct {
X float64
Y float64
}
OutputPath string
}) error {
// 解码Base64图像
imgData := params.InputImgBase64
// 如果包含了数据URL前缀去掉它
if strings.HasPrefix(imgData, "data:image/") {
parts := strings.Split(imgData, ",")
if len(parts) > 1 {
imgData = parts[1]
}
}
// 解码Base64
unbased, err := base64.StdEncoding.DecodeString(imgData)
if err != nil {
return fmt.Errorf("无法解码Base64图像: %w", err)
}
// 解码图像
reader := bytes.NewReader(unbased)
img, _, err := image.Decode(reader)
if err != nil {
return fmt.Errorf("无法解码图像数据: %w", err)
}
// 创建一个可以在其上绘制的图像
bounds := img.Bounds()
rgba := image.NewRGBA(bounds)
draw.Draw(rgba, bounds, img, bounds.Min, draw.Src)
// 在点击/拖动位置绘制标记
markRadius := 10
x, y := int(params.Rect.X), int(params.Rect.Y)
// 绘制红色圆圈
for i := -markRadius; i <= markRadius; i++ {
for j := -markRadius; j <= markRadius; j++ {
if i*i+j*j <= markRadius*markRadius {
if x+i >= 0 && x+i < bounds.Max.X && y+j >= 0 && y+j < bounds.Max.Y {
rgba.Set(x+i, y+j, color.RGBA{255, 0, 0, 255})
}
}
}
}
// 保存图像
outFile, err := os.Create(params.OutputPath)
if err != nil {
return fmt.Errorf("无法创建输出文件: %w", err)
}
defer outFile.Close()
// 编码为PNG并保存
if err := png.Encode(outFile, rgba); err != nil {
return fmt.Errorf("无法编码和保存图像: %w", err)
}
return nil
}
// loadImage loads image and returns base64 encoded string
func loadImage(imagePath string) (base64Str string, err error) {
imageData, err := os.ReadFile(imagePath)
if err != nil {
return "", err
}
base64Str = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
return
}

277
uixt/ai/planner_test.go Normal file
View File

@@ -0,0 +1,277 @@
package ai
import (
"context"
"os"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVLMPlanning(t *testing.T) {
imageBase64, err := loadImage("testdata/llk_3.jpg")
require.NoError(t, err)
userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明:
1. 游戏目标: 玩家需要在规定时间内,通过连接相同的图案或图标,将它们从游戏界面中消除。
2. 连接规则:
- 两个相同的图案可以通过不超过三条直线连接。
- 连接线可以水平或垂直,但不能穿过其他图案。
- 连接线的转折次数不能超过两次。
3. 游戏界面: 游戏界面通常是一个矩形区域,内含多个图案或图标,排列成行和列。
4. 时间限制: 游戏通常设有时间限制,玩家需要在时间耗尽前完成所有图案的消除。
5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。
6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。`
// userInstruction += "\n\n请基于以上游戏规则请先点击第一个图标"
userInstruction += "\n\n点击[3排1列]果酱图案"
planner, err := NewPlanner(context.Background())
require.NoError(t, err)
opts := &PlanningOptions{
UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
},
},
}
// 执行规划
result, err := planner.Call(opts)
// 验证结果
require.NoError(t, err)
require.NotNil(t, result)
require.NotEmpty(t, result.NextActions)
// 验证动作
action := result.NextActions[0]
assert.NotEmpty(t, action.ActionType)
assert.NotEmpty(t, action.Thought)
// 根据动作类型验证参数
switch action.ActionType {
case ActionTypeClick:
// 这些动作需要验证坐标
assert.NotEmpty(t, action.ActionInputs["startBox"])
// 验证坐标格式
coords, ok := action.ActionInputs["startBox"].([]float64)
require.True(t, ok)
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
// 验证坐标范围
for _, coord := range coords {
assert.GreaterOrEqual(t, coord, float64(0))
}
case "wait", "finished", "call_user":
// 这些动作不需要额外参数
default:
t.Fatalf("未知的动作类型: %s", action.ActionType)
}
}
func TestXHSPlanning(t *testing.T) {
imageBase64, err := loadImage("testdata/xhs-feed.jpeg")
require.NoError(t, err)
userInstruction := `点击第二个帖子的作者头像`
planner, err := NewPlanner(context.Background())
require.NoError(t, err)
opts := &PlanningOptions{
UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
},
},
}
// 执行规划
result, err := planner.Call(opts)
// 验证结果
require.NoError(t, err)
require.NotNil(t, result)
require.NotEmpty(t, result.NextActions)
// 验证动作
action := result.NextActions[0]
assert.NotEmpty(t, action.ActionType)
assert.NotEmpty(t, action.Thought)
}
func TestValidateInput(t *testing.T) {
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
require.NoError(t, err)
tests := []struct {
name string
opts *PlanningOptions
wantErr error
}{
{
name: "valid input",
opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
},
},
},
wantErr: nil,
},
{
name: "empty instruction",
opts: &PlanningOptions{
UserInstruction: "",
ConversationHistory: []*schema.Message{
{
Role: schema.User,
Content: "",
},
},
},
wantErr: ErrEmptyInstruction,
},
{
name: "empty conversation history",
opts: &PlanningOptions{
UserInstruction: "点击立即卸载按钮",
ConversationHistory: []*schema.Message{},
},
wantErr: ErrNoConversationHistory,
},
{
name: "invalid image data",
opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{
{
Role: schema.User,
Content: "no image",
},
},
},
wantErr: ErrInvalidImageData,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateInput(tt.opts)
if tt.wantErr != nil {
assert.Error(t, err)
assert.Equal(t, tt.wantErr, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestProcessVLMResponse(t *testing.T) {
tests := []struct {
name string
actions []ParsedAction
wantErr bool
}{
{
name: "valid response",
actions: []ParsedAction{
{
ActionType: "click",
ActionInputs: map[string]interface{}{
"startBox": []float64{0.5, 0.5},
},
Thought: "点击中心位置",
},
},
wantErr: false,
},
{
name: "empty actions",
actions: []ParsedAction{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := processVLMResponse(tt.actions)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, result)
return
}
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, tt.actions, result.NextActions)
})
}
}
func TestSavePositionImg(t *testing.T) {
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
require.NoError(t, err)
params := struct {
InputImgBase64 string
Rect struct {
X float64
Y float64
}
OutputPath string
}{
InputImgBase64: imageBase64,
Rect: struct {
X float64
Y float64
}{
X: 100,
Y: 100,
},
OutputPath: "testdata/output.png",
}
err = SavePositionImg(params)
require.NoError(t, err)
// cleanup
defer os.Remove(params.OutputPath)
}

27
uixt/ai/prompt-ui-tars.go Normal file
View File

@@ -0,0 +1,27 @@
package ai
const uiTarsPlanningPrompt = `
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
Thought: ...
Action: ...
## Action Space
click(start_box='[x1, y1, x2, y2]')
left_double(start_box='[x1, y1, x2, y2]')
right_single(start_box='[x1, y1, x2, y2]')
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
hotkey(key='')
type(content='') #If you want to submit your input, use "\n" at the end of content.
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## Note
- Use Chinese in Thought part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
## User Instruction
`

BIN
uixt/ai/testdata/1.jpeg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 393 KiB

BIN
uixt/ai/testdata/2.jpeg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 390 KiB

BIN
uixt/ai/testdata/llk_1.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 437 KiB

BIN
uixt/ai/testdata/llk_2.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 407 KiB

BIN
uixt/ai/testdata/llk_3.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 123 KiB

BIN
uixt/ai/testdata/popup_risk_warning.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

BIN
uixt/ai/testdata/xhs-feed.jpeg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 649 KiB

View File

@@ -24,7 +24,9 @@ func setupADBDriverExt(t *testing.T) *XTDriver {
driver, err := device.NewDriver()
require.Nil(t, err)
return NewXTDriver(driver,
ai.WithCVService(ai.CVServiceTypeVEDEM))
ai.WithCVService(ai.CVServiceTypeVEDEM),
ai.WithLLMService(ai.LLMServiceTypeUITARS),
)
}
func setupUIA2DriverExt(t *testing.T) *XTDriver {

42
uixt/driver_ext_ai.go Normal file
View File

@@ -0,0 +1,42 @@
package uixt
import (
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/pkg/errors"
)
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
if dExt.LLMService == nil {
return nil, errors.New("LLM service is not initialized")
}
screenShotBase64, err := dExt.GetScreenShotBase64()
if err != nil {
return nil, err
}
planningOpts := &ai.PlanningOptions{
UserInstruction: text,
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: screenShotBase64,
},
},
},
},
},
}
result, err := dExt.LLMService.Call(planningOpts)
if err != nil {
return nil, errors.Wrap(err, "failed to get next action from planner")
}
return result, nil
}

View File

@@ -8,6 +8,22 @@ import (
"github.com/rs/zerolog/log"
)
func (dExt *XTDriver) TapByLLM(text string, opts ...option.ActionOption) error {
text = "[click] " + text
result, err := dExt.PlanNextAction(text, opts...)
if err != nil {
return err
}
action := result.NextActions[0]
if action.ActionType != ai.ActionTypeClick {
return fmt.Errorf("expected click action, got: %s", action.ActionType)
}
point := action.ActionInputs["startBox"].([]float64)
return dExt.TapAbsXY(point[0], point[1], opts...)
}
func (dExt *XTDriver) TapByOCR(text string, opts ...option.ActionOption) error {
actionOptions := option.NewActionOptions(opts...)
if actionOptions.ScreenShotFileName == "" {

View File

@@ -123,6 +123,19 @@ func TestDriverExt_TapByOCR(t *testing.T) {
assert.Nil(t, err)
}
func TestDriverExt_TapByLLM(t *testing.T) {
driver := setupDriverExt(t)
err := driver.TapByLLM("点击第一个帖子的作者头像")
assert.Nil(t, err)
}
func TestDriverExt_PlanNextAction(t *testing.T) {
driver := setupDriverExt(t)
result, err := driver.PlanNextAction("启动抖音")
assert.Nil(t, err)
t.Log(result)
}
func TestDriverExt_prepareSwipeAction(t *testing.T) {
driver := setupDriverExt(t)