refactor: ui-tars planner

This commit is contained in:
lilong.129
2025-03-19 17:48:13 +08:00
parent 6c74727c44
commit 7da305f577
8 changed files with 201 additions and 191 deletions

View File

@@ -10,40 +10,33 @@ import (
)
// NewActionParser creates a new ActionParser instance
func NewActionParser(prediction string, factor float64) *ActionParser {
func NewActionParser(factor float64) *ActionParser {
return &ActionParser{
Prediction: prediction,
Factor: factor,
Factor: factor,
}
}
// ActionParser parses VLM responses and converts them to structured actions
type ActionParser struct {
Prediction string
Factor float64
Factor float64 // TODO
}
// Parse parses the prediction text and extracts actions
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
// try parsing JSON format
// try parsing JSON format, from VLM like GPT-4o
var jsonActions []ParsedAction
jsonActions, jsonErr := p.parseJSON(predictionText)
if jsonErr == nil && len(jsonActions) > 0 {
if jsonErr == nil {
return jsonActions, nil
}
// if JSON parsing fails, try parsing Thought/Action format
// json parsing failed, try parsing Thought/Action format, from VLM like UI-TARS
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
if thoughtErr == nil && len(thoughtActions) > 0 {
if thoughtErr == nil {
return thoughtActions, nil
}
// both parsing methods failed
if jsonErr != nil && thoughtErr != nil {
return nil, fmt.Errorf("failed to parse VLM response: %v; %v", jsonErr, thoughtErr)
}
return nil, fmt.Errorf("no actions returned from VLM")
return nil, fmt.Errorf("no valid actions returned from VLM, jsonErr: %v, thoughtErr: %v", jsonErr, thoughtErr)
}
// parseJSON tries to parse the response as JSON format
@@ -92,7 +85,7 @@ func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction
thought = strings.TrimSpace(thoughtMatch[1])
}
// extract Action part
// extract Action part, e.g. "click(start_box='(552,454)')"
actionMatch := actionRegex.FindStringSubmatch(predictionText)
if len(actionMatch) < 2 {
return nil, fmt.Errorf("no action found in the response")
@@ -125,6 +118,7 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
"call_user": regexp.MustCompile(`call_user\(\)`),
}
parsedActions := make([]ParsedAction, 0)
for actionType, regex := range actionRegexes {
matches := regex.FindStringSubmatch(actionText)
if len(matches) == 0 {
@@ -183,10 +177,13 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
// 这些动作没有额外参数
}
return []ParsedAction{action}, nil
parsedActions = append(parsedActions, action)
}
return nil, fmt.Errorf("unknown action format: %s", actionText)
if len(parsedActions) == 0 {
return nil, fmt.Errorf("no valid actions returned from VLM")
}
return parsedActions, nil
}
// normalizeAction normalizes the coordinates in the action
@@ -215,16 +212,14 @@ func (p *ActionParser) normalizeAction(action *ParsedAction) error {
}
// normalizeCoordinates normalizes the coordinates based on the factor
func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
var coords []float64
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
// check empty string
if coordStr == "" {
return "", fmt.Errorf("empty coordinate string")
return nil, fmt.Errorf("empty coordinate string")
}
if !strings.Contains(coordStr, ",") {
return "", fmt.Errorf("invalid coordinate string: %s", coordStr)
return nil, fmt.Errorf("invalid coordinate string: %s", coordStr)
}
// remove possible brackets and split coordinates
@@ -236,15 +231,9 @@ func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
jsonStr = "[" + coordStr + "]"
}
err := json.Unmarshal([]byte(jsonStr), &coords)
err = json.Unmarshal([]byte(jsonStr), &coords)
if err != nil {
return "", fmt.Errorf("failed to parse coordinate string: %w", err)
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
}
normalized, err := json.Marshal(coords)
if err != nil {
return "", fmt.Errorf("failed to marshal normalized coordinates: %w", err)
}
return string(normalized), nil
return coords, nil
}