feat: add TapByLLM/PlanNextAction for XTDriver
@@ -1,6 +1,7 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -46,6 +47,7 @@ func WithCVService(service CVServiceType) AIServiceOption {
|
||||
type LLMServiceType string
|
||||
|
||||
const (
|
||||
LLMServiceTypeUITARS LLMServiceType = "ui-tars"
|
||||
LLMServiceTypeGPT4o LLMServiceType = "gpt-4o"
|
||||
LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3"
|
||||
)
|
||||
@@ -60,5 +62,12 @@ func WithLLMService(service LLMServiceType) AIServiceOption {
|
||||
os.Exit(code.GetErrorCode(err))
|
||||
}
|
||||
}
|
||||
if service == LLMServiceTypeUITARS {
|
||||
var err error
|
||||
opts.ILLMService, err = NewPlanner(context.Background())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("init ui-tars llm service failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import "testing"
|
||||
func TestOption(t *testing.T) {
|
||||
options := NewAIService(
|
||||
WithCVService(CVServiceTypeOpenCV),
|
||||
WithLLMService(LLMServiceTypeDeepSeekV3),
|
||||
WithLLMService(LLMServiceTypeUITARS),
|
||||
)
|
||||
t.Log(options)
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ type APIResponseImage struct {
|
||||
}
|
||||
|
||||
func NewVEDEMImageService() (*vedemCVService, error) {
|
||||
if err := checkEnv(); err != nil {
|
||||
if err := checkEnvCV(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &vedemCVService{}, nil
|
||||
@@ -230,7 +230,7 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A
|
||||
return imageResult, nil
|
||||
}
|
||||
|
||||
func checkEnv() error {
|
||||
func checkEnvCV() error {
|
||||
vedemImageURL := os.Getenv("VEDEM_IMAGE_URL")
|
||||
if vedemImageURL == "" {
|
||||
return errors.Wrap(code.CVEnvMissedError, "VEDEM_IMAGE_URL missed")
|
||||
|
||||
170
uixt/ai/env.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type OpenAIInitConfig struct {
|
||||
ReportURL string `json:"REPORT_SERVER_URL"`
|
||||
Headers map[string]string `json:"defaultHeaders"`
|
||||
}
|
||||
|
||||
const (
|
||||
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
|
||||
EnvOpenAIAPIKey = "OPENAI_API_KEY"
|
||||
EnvModelName = "MIDSCENE_MODEL_NAME"
|
||||
EnvOpenAIInitConfigJSON = "MIDSCENE_OPENAI_INIT_CONFIG_JSON"
|
||||
EnvUseVLMUITars = "MIDSCENE_USE_VLM_UI_TARS"
|
||||
)
|
||||
|
||||
// loadEnv loads environment variables from a file
|
||||
func loadEnv(envPath string) error {
|
||||
err := godotenv.Load(envPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Str("path", envPath).Msg("load env success")
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetEnvConfig(key string) string {
|
||||
return os.Getenv(key)
|
||||
}
|
||||
|
||||
func GetEnvConfigInJSON(key string) (map[string]interface{}, error) {
|
||||
value := GetEnvConfig(key)
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(value), &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func GetEnvConfigInBool(key string) bool {
|
||||
value := GetEnvConfig(key)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
boolValue, _ := strconv.ParseBool(value)
|
||||
return boolValue
|
||||
}
|
||||
|
||||
// GetEnvConfigOrDefault get env config or default value
|
||||
func GetEnvConfigOrDefault(key, defaultValue string) string {
|
||||
value := GetEnvConfig(key)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func GetEnvConfigInInt(key string, defaultValue int) int {
|
||||
value := GetEnvConfig(key)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intValue, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intValue
|
||||
}
|
||||
|
||||
// CustomTransport is a custom RoundTripper that adds headers to every request
|
||||
type CustomTransport struct {
|
||||
Transport http.RoundTripper
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction and adds custom headers
|
||||
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for key, value := range c.Headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
return c.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// GetModelConfig get OpenAI config
|
||||
func GetModelConfig() (*openai.ChatModelConfig, error) {
|
||||
envConfig := &OpenAIInitConfig{
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
|
||||
// read from JSON config first
|
||||
jsonStr := GetEnvConfig(EnvOpenAIInitConfigJSON)
|
||||
if jsonStr != "" {
|
||||
if err := json.Unmarshal([]byte(jsonStr), envConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
config := &openai.ChatModelConfig{
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: &CustomTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Headers: envConfig.Headers,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if baseURL := GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {
|
||||
config.BaseURL = baseURL
|
||||
} else {
|
||||
return nil, fmt.Errorf("miss env %s", EnvOpenAIBaseURL)
|
||||
}
|
||||
|
||||
if apiKey := GetEnvConfig(EnvOpenAIAPIKey); apiKey != "" {
|
||||
config.APIKey = apiKey
|
||||
} else {
|
||||
return nil, fmt.Errorf("miss env %s", EnvOpenAIAPIKey)
|
||||
}
|
||||
|
||||
if modelName := GetEnvConfig(EnvModelName); modelName != "" {
|
||||
config.Model = modelName
|
||||
} else {
|
||||
return nil, fmt.Errorf("miss env %s", EnvModelName)
|
||||
}
|
||||
|
||||
// log config info
|
||||
log.Info().Str("model", config.Model).
|
||||
Str("baseURL", config.BaseURL).
|
||||
Str("apiKey", maskAPIKey(config.APIKey)).
|
||||
Str("timeout", defaultTimeout.String()).
|
||||
Msg("get model config")
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// maskAPIKey masks the API key
|
||||
func maskAPIKey(key string) string {
|
||||
if len(key) <= 8 {
|
||||
return "******"
|
||||
}
|
||||
|
||||
return key[:4] + "******" + key[len(key)-4:]
|
||||
}
|
||||
|
||||
func IsUseVLMUITars() bool {
|
||||
return GetEnvConfigInBool(EnvUseVLMUITars)
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
package ai
|
||||
|
||||
import "context"
|
||||
import "github.com/cloudwego/eino/schema"
|
||||
|
||||
type ILLMService interface {
|
||||
Call(ctx context.Context, prompt string) (string, error)
|
||||
Call(opts *PlanningOptions) (*PlanningResult, error)
|
||||
}
|
||||
|
||||
func NewGPT4oLLMService() (*openaiLLMService, error) {
|
||||
@@ -12,6 +12,46 @@ func NewGPT4oLLMService() (*openaiLLMService, error) {
|
||||
|
||||
type openaiLLMService struct{}
|
||||
|
||||
func (s openaiLLMService) Call(ctx context.Context, prompt string) (string, error) {
|
||||
return "", nil
|
||||
func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// PlanningOptions represents the input options for planning
|
||||
type PlanningOptions struct {
|
||||
UserInstruction string `json:"user_instruction"`
|
||||
ConversationHistory []*schema.Message `json:"conversation_history"`
|
||||
}
|
||||
|
||||
// PlanningResult represents the result of planning
|
||||
type PlanningResult struct {
|
||||
NextActions []ParsedAction `json:"actions"`
|
||||
ActionSummary string `json:"summary"`
|
||||
}
|
||||
|
||||
// VLMResponse represents the response from the Vision Language Model
|
||||
type VLMResponse struct {
|
||||
Actions []ParsedAction `json:"actions"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ParsedAction represents a parsed action from the VLM response
|
||||
type ParsedAction struct {
|
||||
ActionType ActionType `json:"actionType"`
|
||||
ActionInputs map[string]interface{} `json:"actionInputs"`
|
||||
Thought string `json:"thought"`
|
||||
}
|
||||
|
||||
type ActionType string
|
||||
|
||||
const (
|
||||
ActionTypeClick ActionType = "click"
|
||||
ActionTypeTap ActionType = "tap"
|
||||
ActionTypeDrag ActionType = "drag"
|
||||
ActionTypeSwipe ActionType = "swipe"
|
||||
ActionTypeWait ActionType = "wait"
|
||||
ActionTypeFinished ActionType = "finished"
|
||||
ActionTypeCallUser ActionType = "call_user"
|
||||
ActionTypeType ActionType = "type"
|
||||
ActionTypeScroll ActionType = "scroll"
|
||||
ActionTypeHotkey ActionType = "hotkey"
|
||||
)
|
||||
|
||||
239
uixt/ai/parser.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// NewActionParser creates a new ActionParser instance
|
||||
func NewActionParser(factor float64) *ActionParser {
|
||||
return &ActionParser{
|
||||
Factor: factor,
|
||||
}
|
||||
}
|
||||
|
||||
// ActionParser parses VLM responses and converts them to structured actions
|
||||
type ActionParser struct {
|
||||
Factor float64 // TODO
|
||||
}
|
||||
|
||||
// Parse parses the prediction text and extracts actions
|
||||
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
|
||||
// try parsing JSON format, from VLM like GPT-4o
|
||||
var jsonActions []ParsedAction
|
||||
jsonActions, jsonErr := p.parseJSON(predictionText)
|
||||
if jsonErr == nil {
|
||||
return jsonActions, nil
|
||||
}
|
||||
|
||||
// json parsing failed, try parsing Thought/Action format, from VLM like UI-TARS
|
||||
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
|
||||
if thoughtErr == nil {
|
||||
return thoughtActions, nil
|
||||
}
|
||||
|
||||
return nil, errors.Wrap(thoughtErr, "parse planner response failed")
|
||||
}
|
||||
|
||||
// parseJSON tries to parse the response as JSON format
|
||||
func (p *ActionParser) parseJSON(predictionText string) ([]ParsedAction, error) {
|
||||
predictionText = strings.TrimSpace(predictionText)
|
||||
if strings.HasPrefix(predictionText, "```json") && strings.HasSuffix(predictionText, "```") {
|
||||
predictionText = strings.TrimPrefix(predictionText, "```json")
|
||||
predictionText = strings.TrimSuffix(predictionText, "```")
|
||||
}
|
||||
predictionText = strings.TrimSpace(predictionText)
|
||||
|
||||
var response VLMResponse
|
||||
if err := json.Unmarshal([]byte(predictionText), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse VLM response: %v", err)
|
||||
}
|
||||
|
||||
if response.Error != "" {
|
||||
return nil, errors.New(response.Error)
|
||||
}
|
||||
|
||||
if len(response.Actions) == 0 {
|
||||
return nil, errors.New("no actions returned from VLM")
|
||||
}
|
||||
|
||||
// normalize actions
|
||||
var normalizedActions []ParsedAction
|
||||
for _, action := range response.Actions {
|
||||
if err := p.normalizeAction(&action); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to normalize action")
|
||||
}
|
||||
normalizedActions = append(normalizedActions, action)
|
||||
}
|
||||
|
||||
return normalizedActions, nil
|
||||
}
|
||||
|
||||
// parseThoughtAction parses the Thought/Action format response
|
||||
func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction, error) {
|
||||
thoughtRegex := regexp.MustCompile(`(?is)Thought:(.+?)Action:`)
|
||||
actionRegex := regexp.MustCompile(`(?is)Action:(.+)`)
|
||||
|
||||
// extract Thought part
|
||||
thoughtMatch := thoughtRegex.FindStringSubmatch(predictionText)
|
||||
var thought string
|
||||
if len(thoughtMatch) > 1 {
|
||||
thought = strings.TrimSpace(thoughtMatch[1])
|
||||
}
|
||||
|
||||
// extract Action part, e.g. "click(start_box='(552,454)')"
|
||||
actionMatch := actionRegex.FindStringSubmatch(predictionText)
|
||||
if len(actionMatch) < 2 {
|
||||
return nil, errors.New("no action found in the response")
|
||||
}
|
||||
|
||||
actionText := strings.TrimSpace(actionMatch[1])
|
||||
|
||||
// parse action type and parameters
|
||||
return p.parseActionText(actionText, thought)
|
||||
}
|
||||
|
||||
// parseActionText parses the action text to extract the action type and parameters
|
||||
func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedAction, error) {
|
||||
// remove trailing comments
|
||||
if idx := strings.Index(actionText, "#"); idx > 0 {
|
||||
actionText = strings.TrimSpace(actionText[:idx])
|
||||
}
|
||||
|
||||
// supported action types and regexes
|
||||
actionRegexes := map[ActionType]*regexp.Regexp{
|
||||
"click": regexp.MustCompile(`click\(start_box='([^']+)'\)`),
|
||||
"left_double": regexp.MustCompile(`left_double\(start_box='([^']+)'\)`),
|
||||
"right_single": regexp.MustCompile(`right_single\(start_box='([^']+)'\)`),
|
||||
"drag": regexp.MustCompile(`drag\(start_box='([^']+)', end_box='([^']+)'\)`),
|
||||
"hotkey": regexp.MustCompile(`hotkey\(key='([^']+)'\)`),
|
||||
"type": regexp.MustCompile(`type\(content='([^']+)'\)`),
|
||||
"scroll": regexp.MustCompile(`scroll\(start_box='([^']+)', direction='([^']+)'\)`),
|
||||
"wait": regexp.MustCompile(`wait\(\)`),
|
||||
"finished": regexp.MustCompile(`finished\(\)`),
|
||||
"call_user": regexp.MustCompile(`call_user\(\)`),
|
||||
}
|
||||
|
||||
parsedActions := make([]ParsedAction, 0)
|
||||
for actionType, regex := range actionRegexes {
|
||||
matches := regex.FindStringSubmatch(actionText)
|
||||
if len(matches) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var action ParsedAction
|
||||
action.ActionType = actionType
|
||||
action.ActionInputs = make(map[string]interface{})
|
||||
action.Thought = thought
|
||||
|
||||
// parse parameters based on action type
|
||||
switch actionType {
|
||||
case ActionTypeClick:
|
||||
if len(matches) > 1 {
|
||||
coord, err := p.normalizeCoordinates(matches[1])
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "normalize point failed: %s", matches[1])
|
||||
}
|
||||
action.ActionInputs["startBox"] = coord
|
||||
}
|
||||
case ActionTypeDrag:
|
||||
if len(matches) > 2 {
|
||||
// handle start point
|
||||
startBox, err := p.normalizeCoordinates(matches[1])
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
|
||||
}
|
||||
action.ActionInputs["startBox"] = startBox
|
||||
|
||||
// handle end point
|
||||
endBox, err := p.normalizeCoordinates(matches[2])
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "normalize endBox failed: %s", matches[2])
|
||||
}
|
||||
action.ActionInputs["endBox"] = endBox
|
||||
}
|
||||
case ActionTypeHotkey:
|
||||
if len(matches) > 1 {
|
||||
action.ActionInputs["key"] = matches[1]
|
||||
}
|
||||
case ActionTypeType:
|
||||
if len(matches) > 1 {
|
||||
action.ActionInputs["content"] = matches[1]
|
||||
}
|
||||
case ActionTypeScroll:
|
||||
if len(matches) > 2 {
|
||||
startBox, err := p.normalizeCoordinates(matches[1])
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
|
||||
}
|
||||
action.ActionInputs["startBox"] = startBox
|
||||
action.ActionInputs["direction"] = matches[2]
|
||||
}
|
||||
case ActionTypeWait, ActionTypeFinished, ActionTypeCallUser:
|
||||
// 这些动作没有额外参数
|
||||
}
|
||||
|
||||
parsedActions = append(parsedActions, action)
|
||||
}
|
||||
|
||||
if len(parsedActions) == 0 {
|
||||
return nil, fmt.Errorf("no valid actions returned from VLM")
|
||||
}
|
||||
return parsedActions, nil
|
||||
}
|
||||
|
||||
// normalizeAction normalizes the coordinates in the action
|
||||
func (p *ActionParser) normalizeAction(action *ParsedAction) error {
|
||||
switch action.ActionType {
|
||||
case "click", "drag":
|
||||
// handle click and drag action coordinates
|
||||
if startBox, ok := action.ActionInputs["startBox"].(string); ok {
|
||||
normalized, err := p.normalizeCoordinates(startBox)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to normalize startBox: %w", err)
|
||||
}
|
||||
action.ActionInputs["startBox"] = normalized
|
||||
}
|
||||
|
||||
if endBox, ok := action.ActionInputs["endBox"].(string); ok {
|
||||
normalized, err := p.normalizeCoordinates(endBox)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to normalize endBox: %w", err)
|
||||
}
|
||||
action.ActionInputs["endBox"] = normalized
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeCoordinates normalizes the coordinates based on the factor
|
||||
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
|
||||
// check empty string
|
||||
if coordStr == "" {
|
||||
return nil, fmt.Errorf("empty coordinate string")
|
||||
}
|
||||
|
||||
if !strings.Contains(coordStr, ",") {
|
||||
return nil, fmt.Errorf("invalid coordinate string: %s", coordStr)
|
||||
}
|
||||
|
||||
// remove possible brackets and split coordinates
|
||||
coordStr = strings.Trim(coordStr, "[]() \t")
|
||||
|
||||
// try parsing JSON array
|
||||
jsonStr := coordStr
|
||||
if !strings.HasPrefix(jsonStr, "[") {
|
||||
jsonStr = "[" + coordStr + "]"
|
||||
}
|
||||
|
||||
err = json.Unmarshal([]byte(jsonStr), &coords)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
|
||||
}
|
||||
return coords, nil
|
||||
}
|
||||
352
uixt/ai/planner.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Error types
|
||||
var (
|
||||
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
|
||||
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
|
||||
ErrInvalidImageData = fmt.Errorf("invalid image data")
|
||||
)
|
||||
|
||||
func NewPlanner(ctx context.Context) (*Planner, error) {
|
||||
config, err := GetModelConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
|
||||
}
|
||||
model, err := openai.NewChatModel(ctx, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
|
||||
}
|
||||
parser := NewActionParser(1000)
|
||||
return &Planner{
|
||||
ctx: ctx,
|
||||
model: model,
|
||||
parser: parser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Planner struct {
|
||||
ctx context.Context
|
||||
model *openai.ChatModel
|
||||
parser *ActionParser
|
||||
}
|
||||
|
||||
// Call performs UI planning using Vision Language Model
|
||||
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
|
||||
|
||||
// validate input parameters
|
||||
if err := validateInput(opts); err != nil {
|
||||
return nil, errors.Wrap(err, "validate input parameters failed")
|
||||
}
|
||||
|
||||
// call VLM service
|
||||
resp, err := p.callVLMService(opts)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "call VLM service failed")
|
||||
}
|
||||
|
||||
// parse result
|
||||
result, err := p.parseResult(resp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parse result failed")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Interface("summary", result.ActionSummary).
|
||||
Interface("actions", result.NextActions).
|
||||
Msg("get VLM planning result")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func validateInput(opts *PlanningOptions) error {
|
||||
if opts.UserInstruction == "" {
|
||||
return ErrEmptyInstruction
|
||||
}
|
||||
|
||||
if len(opts.ConversationHistory) == 0 {
|
||||
return ErrNoConversationHistory
|
||||
}
|
||||
|
||||
// ensure at least one image URL
|
||||
hasImageURL := false
|
||||
for _, msg := range opts.ConversationHistory {
|
||||
if msg.Role == "user" {
|
||||
// check MultiContent
|
||||
if len(msg.MultiContent) > 0 {
|
||||
for _, content := range msg.MultiContent {
|
||||
if content.Type == "image_url" && content.ImageURL != nil {
|
||||
hasImageURL = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasImageURL {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasImageURL {
|
||||
return ErrInvalidImageData
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callVLMService makes the actual call to the VLM service
|
||||
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
|
||||
log.Info().Msg("calling VLM service...")
|
||||
|
||||
// prepare prompt
|
||||
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
|
||||
messages := []*schema.Message{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: systemPrompt,
|
||||
},
|
||||
}
|
||||
messages = append(messages, opts.ConversationHistory...)
|
||||
|
||||
// generate response
|
||||
resp, err := p.model.Generate(p.ctx, messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
|
||||
}
|
||||
log.Info().Str("content", resp.Content).Msg("get VLM response")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
|
||||
// parse response
|
||||
actions, err := p.parser.Parse(msg.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse actions: %w", err)
|
||||
}
|
||||
|
||||
// process response
|
||||
result, err := processVLMResponse(actions)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "process VLM response failed")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// processVLMResponse processes the VLM response and converts it to PlanningResult
|
||||
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
|
||||
log.Info().Msg("processing VLM response...")
|
||||
|
||||
if len(actions) == 0 {
|
||||
return nil, fmt.Errorf("no actions returned from VLM")
|
||||
}
|
||||
|
||||
// validate and post-process each action
|
||||
for i := range actions {
|
||||
// validate action type
|
||||
switch actions[i].ActionType {
|
||||
case "click", "left_double", "right_single":
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
case "drag":
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
validateCoordinateAction(&actions[i], "endBox")
|
||||
case "scroll":
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
validateScrollDirection(&actions[i])
|
||||
case "type":
|
||||
validateTypeContent(&actions[i])
|
||||
case "hotkey":
|
||||
validateHotkeyAction(&actions[i])
|
||||
case "wait", "finished", "call_user":
|
||||
// these actions do not need extra parameters
|
||||
default:
|
||||
log.Printf("warning: unknown action type: %s, will try to continue processing", actions[i].ActionType)
|
||||
}
|
||||
}
|
||||
|
||||
// extract action summary
|
||||
actionSummary := extractActionSummary(actions)
|
||||
|
||||
return &PlanningResult{
|
||||
NextActions: actions,
|
||||
ActionSummary: actionSummary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractActionSummary extracts the summary from the actions
|
||||
func extractActionSummary(actions []ParsedAction) string {
|
||||
if len(actions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// use the Thought of the first action as summary
|
||||
if actions[0].Thought != "" {
|
||||
return actions[0].Thought
|
||||
}
|
||||
|
||||
// if no Thought, generate summary from action type
|
||||
action := actions[0]
|
||||
switch action.ActionType {
|
||||
case "click":
|
||||
return "点击操作"
|
||||
case "drag":
|
||||
return "拖拽操作"
|
||||
case "left_double":
|
||||
return "双击操作"
|
||||
case "right_single":
|
||||
return "右键点击操作"
|
||||
case "scroll":
|
||||
direction, _ := action.ActionInputs["direction"].(string)
|
||||
return fmt.Sprintf("滚动操作 (%s)", direction)
|
||||
case "type":
|
||||
content, _ := action.ActionInputs["content"].(string)
|
||||
if len(content) > 20 {
|
||||
content = content[:20] + "..."
|
||||
}
|
||||
return fmt.Sprintf("输入文本: %s", content)
|
||||
case "hotkey":
|
||||
key, _ := action.ActionInputs["key"].(string)
|
||||
return fmt.Sprintf("快捷键: %s", key)
|
||||
case "wait":
|
||||
return "等待操作"
|
||||
case "finished":
|
||||
return "完成操作"
|
||||
case "call_user":
|
||||
return "请求用户协助"
|
||||
default:
|
||||
return fmt.Sprintf("执行 %s 操作", action.ActionType)
|
||||
}
|
||||
}
|
||||
|
||||
// validateCoordinateAction 验证坐标类动作
|
||||
func validateCoordinateAction(action *ParsedAction, boxField string) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
// validateScrollDirection 验证滚动方向
|
||||
func validateScrollDirection(action *ParsedAction) {
|
||||
if direction, ok := action.ActionInputs["direction"].(string); !ok || direction == "" {
|
||||
// default to down
|
||||
action.ActionInputs["direction"] = "down"
|
||||
} else {
|
||||
switch strings.ToLower(direction) {
|
||||
case "up", "down", "left", "right":
|
||||
// keep original direction
|
||||
default:
|
||||
action.ActionInputs["direction"] = "down"
|
||||
log.Warn().Str("direction", direction).Msg("invalid scroll direction, set to default")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateTypeContent 验证输入文本内容
|
||||
func validateTypeContent(action *ParsedAction) {
|
||||
if content, ok := action.ActionInputs["content"]; !ok || content == "" {
|
||||
// default to empty string
|
||||
action.ActionInputs["content"] = ""
|
||||
log.Warn().Msg("type action missing content parameter, set to default")
|
||||
}
|
||||
}
|
||||
|
||||
// validateHotkeyAction 验证快捷键动作
|
||||
func validateHotkeyAction(action *ParsedAction) {
|
||||
if key, ok := action.ActionInputs["key"]; !ok || key == "" {
|
||||
// 为空或缺失的键设置默认值
|
||||
action.ActionInputs["key"] = "Enter"
|
||||
log.Printf("警告: hotkey动作缺少key参数, 已设置默认值")
|
||||
}
|
||||
}
|
||||
|
||||
// SavePositionImg saves an image with position markers
|
||||
func SavePositionImg(params struct {
|
||||
InputImgBase64 string
|
||||
Rect struct {
|
||||
X float64
|
||||
Y float64
|
||||
}
|
||||
OutputPath string
|
||||
}) error {
|
||||
// 解码Base64图像
|
||||
imgData := params.InputImgBase64
|
||||
// 如果包含了数据URL前缀,去掉它
|
||||
if strings.HasPrefix(imgData, "data:image/") {
|
||||
parts := strings.Split(imgData, ",")
|
||||
if len(parts) > 1 {
|
||||
imgData = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 解码Base64
|
||||
unbased, err := base64.StdEncoding.DecodeString(imgData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法解码Base64图像: %w", err)
|
||||
}
|
||||
|
||||
// 解码图像
|
||||
reader := bytes.NewReader(unbased)
|
||||
img, _, err := image.Decode(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法解码图像数据: %w", err)
|
||||
}
|
||||
|
||||
// 创建一个可以在其上绘制的图像
|
||||
bounds := img.Bounds()
|
||||
rgba := image.NewRGBA(bounds)
|
||||
draw.Draw(rgba, bounds, img, bounds.Min, draw.Src)
|
||||
|
||||
// 在点击/拖动位置绘制标记
|
||||
markRadius := 10
|
||||
x, y := int(params.Rect.X), int(params.Rect.Y)
|
||||
|
||||
// 绘制红色圆圈
|
||||
for i := -markRadius; i <= markRadius; i++ {
|
||||
for j := -markRadius; j <= markRadius; j++ {
|
||||
if i*i+j*j <= markRadius*markRadius {
|
||||
if x+i >= 0 && x+i < bounds.Max.X && y+j >= 0 && y+j < bounds.Max.Y {
|
||||
rgba.Set(x+i, y+j, color.RGBA{255, 0, 0, 255})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存图像
|
||||
outFile, err := os.Create(params.OutputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法创建输出文件: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// 编码为PNG并保存
|
||||
if err := png.Encode(outFile, rgba); err != nil {
|
||||
return fmt.Errorf("无法编码和保存图像: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadImage loads image and returns base64 encoded string
|
||||
func loadImage(imagePath string) (base64Str string, err error) {
|
||||
imageData, err := os.ReadFile(imagePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
base64Str = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||
return
|
||||
}
|
||||
277
uixt/ai/planner_test.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVLMPlanning(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/llk_3.jpg")
|
||||
require.NoError(t, err)
|
||||
|
||||
userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明:
|
||||
1. 游戏目标: 玩家需要在规定时间内,通过连接相同的图案或图标,将它们从游戏界面中消除。
|
||||
2. 连接规则:
|
||||
- 两个相同的图案可以通过不超过三条直线连接。
|
||||
- 连接线可以水平或垂直,但不能穿过其他图案。
|
||||
- 连接线的转折次数不能超过两次。
|
||||
3. 游戏界面: 游戏界面通常是一个矩形区域,内含多个图案或图标,排列成行和列。
|
||||
4. 时间限制: 游戏通常设有时间限制,玩家需要在时间耗尽前完成所有图案的消除。
|
||||
5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。
|
||||
6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。`
|
||||
|
||||
// userInstruction += "\n\n请基于以上游戏规则,请先点击第一个图标"
|
||||
|
||||
userInstruction += "\n\n点击[3排1列]果酱图案"
|
||||
|
||||
planner, err := NewPlanner(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := &PlanningOptions{
|
||||
UserInstruction: userInstruction,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 执行规划
|
||||
result, err := planner.Call(opts)
|
||||
|
||||
// 验证结果
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotEmpty(t, result.NextActions)
|
||||
|
||||
// 验证动作
|
||||
action := result.NextActions[0]
|
||||
assert.NotEmpty(t, action.ActionType)
|
||||
assert.NotEmpty(t, action.Thought)
|
||||
|
||||
// 根据动作类型验证参数
|
||||
switch action.ActionType {
|
||||
case ActionTypeClick:
|
||||
// 这些动作需要验证坐标
|
||||
assert.NotEmpty(t, action.ActionInputs["startBox"])
|
||||
|
||||
// 验证坐标格式
|
||||
coords, ok := action.ActionInputs["startBox"].([]float64)
|
||||
require.True(t, ok)
|
||||
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
|
||||
|
||||
// 验证坐标范围
|
||||
for _, coord := range coords {
|
||||
assert.GreaterOrEqual(t, coord, float64(0))
|
||||
}
|
||||
|
||||
case "wait", "finished", "call_user":
|
||||
// 这些动作不需要额外参数
|
||||
|
||||
default:
|
||||
t.Fatalf("未知的动作类型: %s", action.ActionType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXHSPlanning(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/xhs-feed.jpeg")
|
||||
require.NoError(t, err)
|
||||
|
||||
userInstruction := `点击第二个帖子的作者头像`
|
||||
|
||||
planner, err := NewPlanner(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := &PlanningOptions{
|
||||
UserInstruction: userInstruction,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 执行规划
|
||||
result, err := planner.Call(opts)
|
||||
|
||||
// 验证结果
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotEmpty(t, result.NextActions)
|
||||
|
||||
// 验证动作
|
||||
action := result.NextActions[0]
|
||||
assert.NotEmpty(t, action.ActionType)
|
||||
assert.NotEmpty(t, action.Thought)
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *PlanningOptions
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid input",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击继续使用按钮",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "empty instruction",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
Content: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: ErrEmptyInstruction,
|
||||
},
|
||||
{
|
||||
name: "empty conversation history",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击立即卸载按钮",
|
||||
ConversationHistory: []*schema.Message{},
|
||||
},
|
||||
wantErr: ErrNoConversationHistory,
|
||||
},
|
||||
{
|
||||
name: "invalid image data",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击继续使用按钮",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
Content: "no image",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: ErrInvalidImageData,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateInput(tt.opts)
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessVLMResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
actions []ParsedAction
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid response",
|
||||
actions: []ParsedAction{
|
||||
{
|
||||
ActionType: "click",
|
||||
ActionInputs: map[string]interface{}{
|
||||
"startBox": []float64{0.5, 0.5},
|
||||
},
|
||||
Thought: "点击中心位置",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty actions",
|
||||
actions: []ParsedAction{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := processVLMResponse(tt.actions)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tt.actions, result.NextActions)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavePositionImg(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
params := struct {
|
||||
InputImgBase64 string
|
||||
Rect struct {
|
||||
X float64
|
||||
Y float64
|
||||
}
|
||||
OutputPath string
|
||||
}{
|
||||
InputImgBase64: imageBase64,
|
||||
Rect: struct {
|
||||
X float64
|
||||
Y float64
|
||||
}{
|
||||
X: 100,
|
||||
Y: 100,
|
||||
},
|
||||
OutputPath: "testdata/output.png",
|
||||
}
|
||||
|
||||
err = SavePositionImg(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
// cleanup
|
||||
defer os.Remove(params.OutputPath)
|
||||
}
|
||||
27
uixt/ai/prompt-ui-tars.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package ai
|
||||
|
||||
const uiTarsPlanningPrompt = `
|
||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
Thought: ...
|
||||
Action: ...
|
||||
|
||||
## Action Space
|
||||
click(start_box='[x1, y1, x2, y2]')
|
||||
left_double(start_box='[x1, y1, x2, y2]')
|
||||
right_single(start_box='[x1, y1, x2, y2]')
|
||||
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\n" at the end of content.
|
||||
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
|
||||
## Note
|
||||
- Use Chinese in Thought part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
|
||||
|
||||
## User Instruction
|
||||
`
|
||||
BIN
uixt/ai/testdata/1.jpeg
vendored
Normal file
|
After Width: | Height: | Size: 393 KiB |
BIN
uixt/ai/testdata/2.jpeg
vendored
Normal file
|
After Width: | Height: | Size: 390 KiB |
BIN
uixt/ai/testdata/llk_1.png
vendored
Normal file
|
After Width: | Height: | Size: 437 KiB |
BIN
uixt/ai/testdata/llk_2.png
vendored
Normal file
|
After Width: | Height: | Size: 407 KiB |
BIN
uixt/ai/testdata/llk_3.jpg
vendored
Normal file
|
After Width: | Height: | Size: 123 KiB |
BIN
uixt/ai/testdata/popup_risk_warning.png
vendored
Normal file
|
After Width: | Height: | Size: 1.3 MiB |
BIN
uixt/ai/testdata/xhs-feed.jpeg
vendored
Normal file
|
After Width: | Height: | Size: 649 KiB |
@@ -24,7 +24,9 @@ func setupADBDriverExt(t *testing.T) *XTDriver {
|
||||
driver, err := device.NewDriver()
|
||||
require.Nil(t, err)
|
||||
return NewXTDriver(driver,
|
||||
ai.WithCVService(ai.CVServiceTypeVEDEM))
|
||||
ai.WithCVService(ai.CVServiceTypeVEDEM),
|
||||
ai.WithLLMService(ai.LLMServiceTypeUITARS),
|
||||
)
|
||||
}
|
||||
|
||||
func setupUIA2DriverExt(t *testing.T) *XTDriver {
|
||||
|
||||
42
uixt/driver_ext_ai.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package uixt
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/uixt/ai"
|
||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
|
||||
if dExt.LLMService == nil {
|
||||
return nil, errors.New("LLM service is not initialized")
|
||||
}
|
||||
|
||||
screenShotBase64, err := dExt.GetScreenShotBase64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
planningOpts := &ai.PlanningOptions{
|
||||
UserInstruction: text,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: screenShotBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := dExt.LLMService.Call(planningOpts)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get next action from planner")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -8,6 +8,22 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func (dExt *XTDriver) TapByLLM(text string, opts ...option.ActionOption) error {
|
||||
text = "[click] " + text
|
||||
result, err := dExt.PlanNextAction(text, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action := result.NextActions[0]
|
||||
if action.ActionType != ai.ActionTypeClick {
|
||||
return fmt.Errorf("expected click action, got: %s", action.ActionType)
|
||||
}
|
||||
|
||||
point := action.ActionInputs["startBox"].([]float64)
|
||||
return dExt.TapAbsXY(point[0], point[1], opts...)
|
||||
}
|
||||
|
||||
func (dExt *XTDriver) TapByOCR(text string, opts ...option.ActionOption) error {
|
||||
actionOptions := option.NewActionOptions(opts...)
|
||||
if actionOptions.ScreenShotFileName == "" {
|
||||
|
||||
@@ -123,6 +123,19 @@ func TestDriverExt_TapByOCR(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestDriverExt_TapByLLM(t *testing.T) {
|
||||
driver := setupDriverExt(t)
|
||||
err := driver.TapByLLM("点击第一个帖子的作者头像")
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestDriverExt_PlanNextAction(t *testing.T) {
|
||||
driver := setupDriverExt(t)
|
||||
result, err := driver.PlanNextAction("启动抖音")
|
||||
assert.Nil(t, err)
|
||||
t.Log(result)
|
||||
}
|
||||
|
||||
func TestDriverExt_prepareSwipeAction(t *testing.T) {
|
||||
driver := setupDriverExt(t)
|
||||
|
||||
|
||||