mirror of
https://github.com/httprunner/httprunner.git
synced 2026-06-28 02:51:42 +08:00
refactor: merge ai parser
This commit is contained in:
@@ -1 +1 @@
|
||||
v5.0.0-beta-2505232205
|
||||
v5.0.0-beta-2505240025
|
||||
|
||||
@@ -2,8 +2,6 @@ package ai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/httprunner/httprunner/v5/internal/json"
|
||||
@@ -49,109 +47,48 @@ func (p *JSONContentParser) Parse(content string, size types.Size) (*PlanningRes
|
||||
}
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
var response PlanningResult
|
||||
if err := json.Unmarshal([]byte(content), &response); err != nil {
|
||||
// Define a temporary struct to parse the expected JSON format
|
||||
var jsonResponse struct {
|
||||
Actions []Action `json:"actions"`
|
||||
Summary string `json:"summary"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(content), &jsonResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse VLM response: %v", err)
|
||||
}
|
||||
|
||||
if response.Error != "" {
|
||||
return nil, errors.New(response.Error)
|
||||
if jsonResponse.Error != "" {
|
||||
return nil, errors.New(jsonResponse.Error)
|
||||
}
|
||||
|
||||
if len(response.Actions) == 0 {
|
||||
if len(jsonResponse.Actions) == 0 {
|
||||
return nil, errors.New("no actions returned from VLM")
|
||||
}
|
||||
|
||||
// normalize actions
|
||||
// normalize actions using unified function from ui-tars parser
|
||||
var normalizedActions []Action
|
||||
for i := range response.Actions {
|
||||
for i := range jsonResponse.Actions {
|
||||
// create a new variable, avoid implicit memory aliasing in for loop.
|
||||
action := response.Actions[i]
|
||||
if err := normalizeAction(&action); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to normalize action")
|
||||
action := jsonResponse.Actions[i]
|
||||
|
||||
// Process and normalize arguments (from JSON parser)
|
||||
processedArgs, err := processActionArguments(action.ActionInputs, size)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to process action arguments")
|
||||
}
|
||||
action.ActionInputs = processedArgs
|
||||
|
||||
normalizedActions = append(normalizedActions, action)
|
||||
}
|
||||
|
||||
// Convert actions to tool calls using function from parser_ui_tars.go
|
||||
toolCalls := convertActionsToToolCalls(normalizedActions)
|
||||
|
||||
return &PlanningResult{
|
||||
Actions: normalizedActions,
|
||||
ActionSummary: response.ActionSummary,
|
||||
ToolCalls: toolCalls,
|
||||
ActionSummary: jsonResponse.Summary,
|
||||
Thought: jsonResponse.Summary,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeAction normalizes the coordinates in the action
|
||||
func normalizeAction(action *Action) error {
|
||||
switch action.ActionType {
|
||||
case "click", "drag":
|
||||
// handle click and drag action coordinates
|
||||
if startBox, ok := action.ActionInputs["startBox"].(string); ok {
|
||||
normalized, err := normalizeCoordinates(startBox)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to normalize startBox: %w", err)
|
||||
}
|
||||
action.ActionInputs["startBox"] = normalized
|
||||
}
|
||||
|
||||
if endBox, ok := action.ActionInputs["endBox"].(string); ok {
|
||||
normalized, err := normalizeCoordinates(endBox)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to normalize endBox: %w", err)
|
||||
}
|
||||
action.ActionInputs["endBox"] = normalized
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeCoordinates normalizes the coordinates based on the factor
|
||||
func normalizeCoordinates(coordStr string) (coords []float64, err error) {
|
||||
// check empty string
|
||||
if coordStr == "" {
|
||||
return nil, fmt.Errorf("empty coordinate string")
|
||||
}
|
||||
|
||||
// handle BBox format: <bbox>x1 y1 x2 y2</bbox>
|
||||
bboxRegex := regexp.MustCompile(`<bbox>(\d+\s+\d+\s+\d+\s+\d+)</bbox>`)
|
||||
bboxMatches := bboxRegex.FindStringSubmatch(coordStr)
|
||||
if len(bboxMatches) > 1 {
|
||||
// Extract space-separated values from inside the bbox tags
|
||||
bboxContent := bboxMatches[1]
|
||||
// Split by whitespace
|
||||
parts := strings.Fields(bboxContent)
|
||||
if len(parts) == 4 {
|
||||
coords = make([]float64, 4)
|
||||
for i, part := range parts {
|
||||
val, e := strconv.ParseFloat(part, 64)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("failed to parse coordinate value '%s': %w", part, e)
|
||||
}
|
||||
coords[i] = val
|
||||
}
|
||||
// 将 val 转换为 [x,y] 坐标
|
||||
x := (coords[0] + coords[2]) / 2
|
||||
y := (coords[1] + coords[3]) / 2
|
||||
return []float64{x, y}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// handle coordinate string, e.g. "[100, 200]", "(100, 200)"
|
||||
if strings.Contains(coordStr, ",") {
|
||||
// remove possible brackets and split coordinates
|
||||
coordStr = strings.Trim(coordStr, "[]() \t")
|
||||
|
||||
// try parsing JSON array
|
||||
jsonStr := coordStr
|
||||
if !strings.HasPrefix(jsonStr, "[") {
|
||||
jsonStr = "[" + coordStr + "]"
|
||||
}
|
||||
|
||||
err = json.Unmarshal([]byte(jsonStr), &coords)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
|
||||
}
|
||||
return coords, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid coordinate string format: %s", coordStr)
|
||||
}
|
||||
|
||||
@@ -3,41 +3,42 @@ package ai
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/httprunner/httprunner/v5/internal/json"
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseAction(t *testing.T) {
|
||||
actionStr := "click(point='<point>200 300</point>')"
|
||||
result, err := ParseAction(actionStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, result.Function, "click")
|
||||
assert.Equal(t, result.Args["point"], "<point>200 300</point>")
|
||||
}
|
||||
|
||||
func TestParseActionToStructureOutput(t *testing.T) {
|
||||
text := "Thought: test\nAction: click(point='<point>200 300</point>')"
|
||||
parser := &UITARSContentParser{}
|
||||
result, err := parser.Parse(text, types.Size{Height: 224, Width: 224})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, result.Actions[0].ActionType, "click")
|
||||
assert.Contains(t, result.Actions[0].ActionInputs, "start_box")
|
||||
function := result.ToolCalls[0].Function
|
||||
assert.Equal(t, function.Name, "click")
|
||||
assert.Contains(t, function.Arguments, "start_box")
|
||||
|
||||
text = "Thought: 我看到页面上有几个帖子,第二个帖子的标题是\"字节四年,头发白了\"。要完成任务,我需要点击这个帖子下方的作者头像,这样就能进入作者的个人主页了。\nAction: click(start_point='<point>550 450 550 450</point>')"
|
||||
result, err = parser.Parse(text, types.Size{Height: 2341, Width: 1024})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, result.Actions[0].ActionType, "click")
|
||||
assert.Contains(t, result.Actions[0].ActionInputs, "start_box")
|
||||
function = result.ToolCalls[0].Function
|
||||
assert.Equal(t, function.Name, "click")
|
||||
assert.Contains(t, function.Arguments, "start_box")
|
||||
|
||||
// Test new bracket format
|
||||
text = "Thought: 我需要点击这个按钮\nAction: click(start_box='[100, 200, 150, 250]')"
|
||||
result, err = parser.Parse(text, types.Size{Height: 1000, Width: 1000})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, result.Actions[0].ActionType, "click")
|
||||
assert.Contains(t, result.Actions[0].ActionInputs, "start_box")
|
||||
coords := result.Actions[0].ActionInputs["start_box"].([]float64)
|
||||
function = result.ToolCalls[0].Function
|
||||
assert.Equal(t, function.Name, "click")
|
||||
assert.Contains(t, function.Arguments, "start_box")
|
||||
arguments := make(map[string]interface{})
|
||||
err = json.Unmarshal([]byte(function.Arguments), &arguments)
|
||||
assert.Nil(t, err)
|
||||
coordsInterface := arguments["start_box"].([]interface{})
|
||||
coords := make([]float64, len(coordsInterface))
|
||||
for i, v := range coordsInterface {
|
||||
coords[i] = v.(float64)
|
||||
}
|
||||
assert.Equal(t, 4, len(coords))
|
||||
assert.Equal(t, 100.0, coords[0])
|
||||
assert.Equal(t, 200.0, coords[1])
|
||||
@@ -48,13 +49,608 @@ func TestParseActionToStructureOutput(t *testing.T) {
|
||||
text = "Thought: 我需要拖拽元素\nAction: drag(start_box='[100, 200, 150, 250]', end_box='[300, 400, 350, 450]')"
|
||||
result, err = parser.Parse(text, types.Size{Height: 1000, Width: 1000})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, result.Actions[0].ActionType, "drag")
|
||||
assert.Contains(t, result.Actions[0].ActionInputs, "start_box")
|
||||
assert.Contains(t, result.Actions[0].ActionInputs, "end_box")
|
||||
startCoords := result.Actions[0].ActionInputs["start_box"].([]float64)
|
||||
endCoords := result.Actions[0].ActionInputs["end_box"].([]float64)
|
||||
function = result.ToolCalls[0].Function
|
||||
assert.Equal(t, function.Name, "drag")
|
||||
assert.Contains(t, function.Arguments, "start_box")
|
||||
assert.Contains(t, function.Arguments, "end_box")
|
||||
arguments = make(map[string]interface{})
|
||||
err = json.Unmarshal([]byte(function.Arguments), &arguments)
|
||||
assert.Nil(t, err)
|
||||
startCoordsInterface := arguments["start_box"].([]interface{})
|
||||
endCoordsInterface := arguments["end_box"].([]interface{})
|
||||
startCoords := make([]float64, len(startCoordsInterface))
|
||||
endCoords := make([]float64, len(endCoordsInterface))
|
||||
for i, v := range startCoordsInterface {
|
||||
startCoords[i] = v.(float64)
|
||||
}
|
||||
for i, v := range endCoordsInterface {
|
||||
endCoords[i] = v.(float64)
|
||||
}
|
||||
assert.Equal(t, 4, len(startCoords))
|
||||
assert.Equal(t, 4, len(endCoords))
|
||||
assert.Equal(t, 100.0, startCoords[0])
|
||||
assert.Equal(t, 300.0, endCoords[0])
|
||||
}
|
||||
|
||||
// Test normalizeCoordinatesFormat function
|
||||
func TestNormalizeCoordinatesFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "point tag with 2 numbers",
|
||||
input: "<point>100 200</point>",
|
||||
expected: "(100,200)",
|
||||
},
|
||||
{
|
||||
name: "point tag with 4 numbers",
|
||||
input: "<point>100 200 150 250</point>",
|
||||
expected: "(100,200,150,250)",
|
||||
},
|
||||
{
|
||||
name: "bbox tag",
|
||||
input: "<bbox>100 200 150 250</bbox>",
|
||||
expected: "(100,200,150,250)",
|
||||
},
|
||||
{
|
||||
name: "bracket format",
|
||||
input: "[100, 200, 150, 250]",
|
||||
expected: "(100,200,150,250)",
|
||||
},
|
||||
{
|
||||
name: "bracket format with spaces",
|
||||
input: "[100, 200, 150, 250]",
|
||||
expected: "(100,200,150,250)",
|
||||
},
|
||||
{
|
||||
name: "multiple point tags",
|
||||
input: "<point>100 200</point> and <point>300 400</point>",
|
||||
expected: "(100,200) and (300,400)",
|
||||
},
|
||||
{
|
||||
name: "no coordinates",
|
||||
input: "click on button",
|
||||
expected: "click on button",
|
||||
},
|
||||
{
|
||||
name: "mixed formats",
|
||||
input: "<point>100 200</point> and [300, 400, 350, 450]",
|
||||
expected: "(100,200) and (300,400,350,450)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeCoordinatesFormat(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test convertRelativeToAbsolute function
|
||||
func TestConvertRelativeToAbsolute(t *testing.T) {
|
||||
size := types.Size{Width: 1000, Height: 2000}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
relativeCoord float64
|
||||
isXCoord bool
|
||||
expectedResult float64
|
||||
}{
|
||||
{
|
||||
name: "x coordinate conversion",
|
||||
relativeCoord: 500, // 500/1000 * 1000 = 500
|
||||
isXCoord: true,
|
||||
expectedResult: 500.0,
|
||||
},
|
||||
{
|
||||
name: "y coordinate conversion",
|
||||
relativeCoord: 500, // 500/1000 * 2000 = 1000
|
||||
isXCoord: false,
|
||||
expectedResult: 1000.0,
|
||||
},
|
||||
{
|
||||
name: "x coordinate with rounding",
|
||||
relativeCoord: 333, // 333/1000 * 1000 = 333
|
||||
isXCoord: true,
|
||||
expectedResult: 333.0,
|
||||
},
|
||||
{
|
||||
name: "y coordinate with rounding",
|
||||
relativeCoord: 750, // 750/1000 * 2000 = 1500
|
||||
isXCoord: false,
|
||||
expectedResult: 1500.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertRelativeToAbsolute(tt.relativeCoord, tt.isXCoord, size)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test parseActionTypeAndArguments function
|
||||
func TestParseActionTypeAndArguments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
actionStr string
|
||||
expectedType string
|
||||
expectedArgs map[string]interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "simple click action",
|
||||
actionStr: "click(start_box='100,200,150,250')",
|
||||
expectedType: "click",
|
||||
expectedArgs: map[string]interface{}{
|
||||
"start_box": "100,200,150,250",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "drag action with two parameters",
|
||||
actionStr: "drag(start_box='100,200,150,250', end_box='300,400,350,450')",
|
||||
expectedType: "drag",
|
||||
expectedArgs: map[string]interface{}{
|
||||
"start_box": "100,200,150,250",
|
||||
"end_box": "300,400,350,450",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "parameter name mapping - start_point to start_box",
|
||||
actionStr: "click(start_point='100,200,150,250')",
|
||||
expectedType: "click",
|
||||
expectedArgs: map[string]interface{}{
|
||||
"start_box": "100,200,150,250", // should be mapped from start_point
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "parameter name mapping - point to start_box",
|
||||
actionStr: "click(point='100,200')",
|
||||
expectedType: "click",
|
||||
expectedArgs: map[string]interface{}{
|
||||
"start_box": "100,200", // should be mapped from point
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "type action with content",
|
||||
actionStr: "type(content='Hello World')",
|
||||
expectedType: "type",
|
||||
expectedArgs: map[string]interface{}{
|
||||
"content": "Hello World",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "action without parameters",
|
||||
actionStr: "press_home()",
|
||||
expectedType: "press_home",
|
||||
expectedArgs: map[string]interface{}{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format - no parentheses",
|
||||
actionStr: "click",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid format - missing closing parenthesis",
|
||||
actionStr: "click(start_box='100,200'",
|
||||
expectedType: "click",
|
||||
expectedArgs: map[string]interface{}{
|
||||
"start_box": "100,200", // 正则表达式能够匹配到这个参数
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actionType, rawArgs, err := parseActionTypeAndArguments(tt.actionStr)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedType, actionType)
|
||||
assert.Equal(t, tt.expectedArgs, rawArgs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test normalizeParameterName function
|
||||
func TestNormalizeParameterName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "start_point to start_box",
|
||||
input: "start_point",
|
||||
expected: "start_box",
|
||||
},
|
||||
{
|
||||
name: "end_point to end_box",
|
||||
input: "end_point",
|
||||
expected: "end_box",
|
||||
},
|
||||
{
|
||||
name: "point to start_box",
|
||||
input: "point",
|
||||
expected: "start_box",
|
||||
},
|
||||
{
|
||||
name: "unchanged parameter",
|
||||
input: "content",
|
||||
expected: "content",
|
||||
},
|
||||
{
|
||||
name: "unchanged parameter - direction",
|
||||
input: "direction",
|
||||
expected: "direction",
|
||||
},
|
||||
{
|
||||
name: "unchanged parameter - start_box",
|
||||
input: "start_box",
|
||||
expected: "start_box",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeParameterName(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test isCoordinateParameter function
|
||||
func TestIsCoordinateParameter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
paramName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "start_box is coordinate",
|
||||
paramName: "start_box",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "end_box is coordinate",
|
||||
paramName: "end_box",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "start_point is coordinate",
|
||||
paramName: "start_point",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "end_point is coordinate",
|
||||
paramName: "end_point",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "content is not coordinate",
|
||||
paramName: "content",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "direction is not coordinate",
|
||||
paramName: "direction",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "key is not coordinate",
|
||||
paramName: "key",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isCoordinateParameter(tt.paramName)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test normalizeStringParam function
|
||||
func TestNormalizeStringParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
paramName string
|
||||
paramValue interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "content with escape characters",
|
||||
paramName: "content",
|
||||
paramValue: "Hello\\nWorld\\\"Test\\'",
|
||||
expected: "Hello\nWorld\"Test'",
|
||||
},
|
||||
{
|
||||
name: "content without escape characters",
|
||||
paramName: "content",
|
||||
paramValue: "Hello World",
|
||||
expected: "Hello World",
|
||||
},
|
||||
{
|
||||
name: "non-content parameter with escape characters",
|
||||
paramName: "direction",
|
||||
paramValue: "down\\nup",
|
||||
expected: "down\\nup", // should not process escape chars
|
||||
},
|
||||
{
|
||||
name: "string with leading/trailing spaces",
|
||||
paramName: "content",
|
||||
paramValue: " Hello World ",
|
||||
expected: "Hello World",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
paramName: "content",
|
||||
paramValue: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "nil value",
|
||||
paramName: "content",
|
||||
paramValue: nil,
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "non-string value",
|
||||
paramName: "content",
|
||||
paramValue: 123,
|
||||
expected: 123,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeStringParam(tt.paramName, tt.paramValue)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test normalizeStringCoordinates function
|
||||
func TestNormalizeStringCoordinates(t *testing.T) {
|
||||
size := types.Size{Width: 1000, Height: 1000}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
coordStr string
|
||||
expected []float64
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "simple coordinate string",
|
||||
coordStr: "100,200,150,250",
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "coordinate string with spaces",
|
||||
coordStr: " 100 , 200 , 150 , 250 ",
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "point tag format",
|
||||
coordStr: "<point>100 200</point>",
|
||||
expected: []float64{100.0, 200.0},
|
||||
},
|
||||
{
|
||||
name: "bbox tag format",
|
||||
coordStr: "<bbox>100 200 150 250</bbox>",
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "bracket format",
|
||||
coordStr: "[100, 200, 150, 250]",
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
coordStr: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid coordinate string",
|
||||
coordStr: "abc,def",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "insufficient coordinates",
|
||||
coordStr: "100",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := normalizeStringCoordinates(tt.coordStr, size)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.expected), len(result))
|
||||
for i, expected := range tt.expected {
|
||||
assert.Equal(t, expected, result[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test normalizeActionCoordinates function
|
||||
func TestNormalizeActionCoordinates(t *testing.T) {
|
||||
size := types.Size{Width: 1000, Height: 1000}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
coordData interface{}
|
||||
expected []float64
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "JSON array format - []interface{}",
|
||||
coordData: []interface{}{100.0, 200.0, 150.0, 250.0},
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "JSON array format with int values",
|
||||
coordData: []interface{}{100, 200, 150, 250},
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "float64 slice format",
|
||||
coordData: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "string format",
|
||||
coordData: "100,200,150,250",
|
||||
expected: []float64{100.0, 200.0, 150.0, 250.0},
|
||||
},
|
||||
{
|
||||
name: "two-element coordinate",
|
||||
coordData: []interface{}{100.0, 200.0},
|
||||
expected: []float64{100.0, 200.0},
|
||||
},
|
||||
{
|
||||
name: "insufficient elements in array",
|
||||
coordData: []interface{}{100.0},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid array element type",
|
||||
coordData: []interface{}{"abc", 200.0},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported coordinate format",
|
||||
coordData: map[string]interface{}{"x": 100, "y": 200},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := normalizeActionCoordinates(tt.coordData, size)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.expected), len(result))
|
||||
for i, expected := range tt.expected {
|
||||
assert.Equal(t, expected, result[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test processActionArguments function
|
||||
func TestProcessActionArguments(t *testing.T) {
|
||||
size := types.Size{Width: 1000, Height: 1000}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rawArgs map[string]interface{}
|
||||
expected map[string]interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "coordinate and non-coordinate parameters",
|
||||
rawArgs: map[string]interface{}{
|
||||
"start_box": "100,200,150,250",
|
||||
"content": "Hello\\nWorld",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"start_box": []float64{100.0, 200.0, 150.0, 250.0},
|
||||
"content": "Hello\nWorld",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple coordinate parameters",
|
||||
rawArgs: map[string]interface{}{
|
||||
"start_box": "100,200,150,250",
|
||||
"end_box": "300,400,350,450",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"start_box": []float64{100.0, 200.0, 150.0, 250.0},
|
||||
"end_box": []float64{300.0, 400.0, 350.0, 450.0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "only non-coordinate parameters",
|
||||
rawArgs: map[string]interface{}{
|
||||
"content": "Hello World",
|
||||
"direction": "down",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"content": "Hello World",
|
||||
"direction": "down",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty arguments",
|
||||
rawArgs: map[string]interface{}{},
|
||||
expected: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "invalid coordinate parameter",
|
||||
rawArgs: map[string]interface{}{
|
||||
"start_box": "invalid",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := processActionArguments(tt.rawArgs, size)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.expected), len(result))
|
||||
|
||||
for key, expectedValue := range tt.expected {
|
||||
actualValue, exists := result[key]
|
||||
assert.True(t, exists, "Key %s should exist in result", key)
|
||||
|
||||
// Handle slice comparison separately
|
||||
if expectedSlice, ok := expectedValue.([]float64); ok {
|
||||
actualSlice, ok := actualValue.([]float64)
|
||||
assert.True(t, ok, "Value for key %s should be []float64", key)
|
||||
assert.Equal(t, len(expectedSlice), len(actualSlice))
|
||||
for i, expected := range expectedSlice {
|
||||
assert.Equal(t, expected, actualSlice[i])
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,9 +14,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// reference:
|
||||
// https://github.com/bytedance/UI-TARS/blob/main/codes/ui_tars/action_parser.py
|
||||
|
||||
const (
|
||||
DefaultFactor = 1000
|
||||
)
|
||||
@@ -32,35 +29,31 @@ func (p *UITARSContentParser) SystemPrompt() string {
|
||||
|
||||
// ParseActionToStructureOutput parses the model output text into structured actions.
|
||||
func (p *UITARSContentParser) Parse(content string, size types.Size) (*PlanningResult, error) {
|
||||
text := strings.TrimSpace(content)
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Extract thought/reflection
|
||||
thought := p.extractThought(text)
|
||||
// Extract thought string
|
||||
thought := p.extractThought(content)
|
||||
|
||||
// Normalize text first
|
||||
normalizedText := p.normalizeCoordinates(text)
|
||||
|
||||
// Get action string from normalized text
|
||||
actionStr, err := p.extractActionString(normalizedText)
|
||||
// Extract action string
|
||||
actionStr, err := p.extractActionString(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse actions directly
|
||||
// Parse and process actions
|
||||
actions, err := p.parseActionString(actionStr, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert actions to tool calls
|
||||
toolCalls := p.convertActionsToToolCalls(actions)
|
||||
toolCalls := convertActionsToToolCalls(actions)
|
||||
|
||||
return &PlanningResult{
|
||||
ToolCalls: toolCalls,
|
||||
Actions: actions,
|
||||
ActionSummary: thought,
|
||||
Thought: thought,
|
||||
Text: normalizedText,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -85,8 +78,31 @@ func (p *UITARSContentParser) extractActionString(text string) (string, error) {
|
||||
return "", fmt.Errorf("no Action: found")
|
||||
}
|
||||
|
||||
// normalizeCoordinates normalizes the text by converting points to coordinates and replacing keywords
|
||||
func (p *UITARSContentParser) normalizeCoordinates(text string) string {
|
||||
// parseActionString parse and process actions
|
||||
func (p *UITARSContentParser) parseActionString(actionStr string, size types.Size) ([]Action, error) {
|
||||
// Parse action type and raw arguments
|
||||
actionType, rawArgs, err := parseActionTypeAndArguments(actionStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process and normalize arguments
|
||||
processedArgs, err := processActionArguments(rawArgs, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create final action
|
||||
action := Action{
|
||||
ActionType: actionType,
|
||||
ActionInputs: processedArgs,
|
||||
}
|
||||
|
||||
return []Action{action}, nil
|
||||
}
|
||||
|
||||
// normalizeCoordinatesFormat standardizes coordinate format in text (without pixel conversion)
|
||||
func normalizeCoordinatesFormat(text string) string {
|
||||
// Convert point tags to coordinate format
|
||||
if strings.Contains(text, "<point>") {
|
||||
// support <point>x1 y1 x2 y2</point> or <point>x y</point>
|
||||
@@ -127,28 +143,32 @@ func (p *UITARSContentParser) normalizeCoordinates(text string) string {
|
||||
})
|
||||
}
|
||||
|
||||
// Legacy parameter name replacements (keep for backward compatibility)
|
||||
text = strings.ReplaceAll(text, "start_point=", "start_box=")
|
||||
text = strings.ReplaceAll(text, "end_point=", "end_box=")
|
||||
text = strings.ReplaceAll(text, "point=", "start_box=")
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// parseActionString parses the action string directly
|
||||
func (p *UITARSContentParser) parseActionString(actionStr string, size types.Size) ([]Action, error) {
|
||||
actions := make([]Action, 0, 1)
|
||||
// convertRelativeToAbsolute converts relative coordinates to absolute pixel coordinates
|
||||
func convertRelativeToAbsolute(relativeCoord float64, isXCoord bool, size types.Size) float64 {
|
||||
if isXCoord {
|
||||
return math.Round((relativeCoord/DefaultFactor*float64(size.Width))*10) / 10
|
||||
}
|
||||
return math.Round((relativeCoord/DefaultFactor*float64(size.Height))*10) / 10
|
||||
}
|
||||
|
||||
// parseActionTypeAndArguments extracts function name and raw parameter map from action string
|
||||
// Input: "click(start_box='100,200,150,250')" or "click(start_point='100,200,150,250')"
|
||||
// Output: actionType="click", rawArgs={"start_box": "100,200,150,250"}
|
||||
func parseActionTypeAndArguments(actionStr string) (actionType string, rawArgs map[string]interface{}, err error) {
|
||||
// Parse action type and parameters
|
||||
actionParts := strings.SplitN(actionStr, "(", 2)
|
||||
if len(actionParts) < 2 {
|
||||
return nil, fmt.Errorf("not a function call")
|
||||
return "", nil, fmt.Errorf("not a function call")
|
||||
}
|
||||
|
||||
funcName := strings.TrimSpace(actionParts[0])
|
||||
actionType = strings.TrimSpace(actionParts[0])
|
||||
paramsText := strings.TrimSuffix(strings.TrimSpace(actionParts[1]), ")")
|
||||
|
||||
args := make(map[string]string)
|
||||
// Parse string parameters to map
|
||||
rawArgs = make(map[string]interface{})
|
||||
if paramsText != "" {
|
||||
// Use regex to extract key=value pairs, handling quoted values properly
|
||||
re := regexp.MustCompile(`(\w+)\s*=\s*['"]([^'"]*?)['"]`)
|
||||
@@ -157,76 +177,188 @@ func (p *UITARSContentParser) parseActionString(actionStr string, size types.Siz
|
||||
if len(match) >= 3 {
|
||||
key := strings.TrimSpace(match[1])
|
||||
value := strings.TrimSpace(match[2])
|
||||
args[key] = value
|
||||
|
||||
// Apply parameter name mapping (legacy compatibility)
|
||||
key = normalizeParameterName(key)
|
||||
rawArgs[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actionInputs, err := p.parseActionInputs(args, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
actions = append(actions, Action{
|
||||
ActionType: funcName,
|
||||
ActionInputs: actionInputs,
|
||||
})
|
||||
|
||||
return actions, nil
|
||||
return actionType, rawArgs, nil
|
||||
}
|
||||
|
||||
// parseActionInputs parses action parameters and converts coordinates
|
||||
func (p *UITARSContentParser) parseActionInputs(args map[string]string, size types.Size) (map[string]any, error) {
|
||||
actionInputs := make(map[string]any)
|
||||
imageWidth := size.Width
|
||||
imageHeight := size.Height
|
||||
// normalizeParameterName rewrites legacy point-style parameter names to their
// box-style equivalents. Names without a legacy alias are returned unchanged.
func normalizeParameterName(paramName string) string {
	// Both "start_point" and the bare "point" are aliases of "start_box".
	if paramName == "start_point" || paramName == "point" {
		return "start_box"
	}
	if paramName == "end_point" {
		return "end_box"
	}
	return paramName
}
|
||||
|
||||
for paramName, param := range args {
|
||||
if param == "" {
|
||||
continue
|
||||
}
|
||||
param = strings.TrimSpace(param)
|
||||
// processActionArguments processes raw arguments based on action type and parameter types
|
||||
// Input: rawArgs={"start_box": "100,200,150,250"}
|
||||
// Output: processedArgs={"start_box": [120.5, 240.1, 180.7, 300.2]} (converted to pixels)
|
||||
func processActionArguments(rawArgs map[string]interface{}, size types.Size) (map[string]interface{}, error) {
|
||||
processedArgs := make(map[string]interface{})
|
||||
|
||||
// Convert box coordinates
|
||||
if strings.Contains(paramName, "box") || strings.Contains(paramName, "point") {
|
||||
// Extract numbers from the parameter value using regex
|
||||
re := regexp.MustCompile(`\d+`)
|
||||
numbers := re.FindAllString(param, -1)
|
||||
if len(numbers) >= 2 {
|
||||
coords := make([]float64, len(numbers))
|
||||
for i, numStr := range numbers {
|
||||
num, err := strconv.ParseFloat(numStr, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid coordinate: %s", numStr)
|
||||
}
|
||||
// Convert relative coordinates to absolute coordinates
|
||||
if i%2 == 0 { // x coordinates
|
||||
coords[i] = math.Round((num/DefaultFactor*float64(imageWidth))*10) / 10
|
||||
} else { // y coordinates
|
||||
coords[i] = math.Round((num/DefaultFactor*float64(imageHeight))*10) / 10
|
||||
}
|
||||
}
|
||||
actionInputs[paramName] = coords
|
||||
} else {
|
||||
actionInputs[paramName] = param
|
||||
}
|
||||
} else {
|
||||
// Handle other parameter types (content, key, direction, etc.)
|
||||
if paramName == "content" {
|
||||
// Handle escape characters
|
||||
param = strings.ReplaceAll(param, "\\n", "\n")
|
||||
param = strings.ReplaceAll(param, "\\\"", "\"")
|
||||
param = strings.ReplaceAll(param, "\\'", "'")
|
||||
}
|
||||
actionInputs[paramName] = param
|
||||
// Process each argument based on its type and context
|
||||
for paramName, paramValue := range rawArgs {
|
||||
processed, err := processArgument(paramName, paramValue, size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process argument %s: %w", paramName, err)
|
||||
}
|
||||
processedArgs[paramName] = processed
|
||||
}
|
||||
|
||||
return actionInputs, nil
|
||||
return processedArgs, nil
|
||||
}
|
||||
|
||||
// Process a single argument based on its name and value
|
||||
func processArgument(paramName string, paramValue interface{}, size types.Size) (interface{}, error) {
|
||||
// Handle coordinate parameters
|
||||
if isCoordinateParameter(paramName) {
|
||||
return normalizeActionCoordinates(paramValue, size)
|
||||
}
|
||||
|
||||
// Handle other parameter types (content, key, direction, etc.)
|
||||
return normalizeStringParam(paramName, paramValue), nil
|
||||
}
|
||||
|
||||
// isCoordinateParameter reports whether paramName denotes a coordinate value,
// i.e. whether the name contains "box" or "point".
//
// The match is case-insensitive so that camelCase names such as "startBox" and
// "endBox" are recognized in addition to snake_case names such as "start_box"
// and "end_point".
func isCoordinateParameter(paramName string) bool {
	name := strings.ToLower(paramName)
	return strings.Contains(name, "box") || strings.Contains(name, "point")
}
|
||||
|
||||
// normalizeActionCoordinates normalizes coordinates from various formats to actual pixel coordinates
|
||||
func normalizeActionCoordinates(coordData interface{}, size types.Size) ([]float64, error) {
|
||||
switch v := coordData.(type) {
|
||||
case []interface{}:
|
||||
// Handle JSON array format: [x1, y1, x2, y2] or [x1, y1]
|
||||
if len(v) < 2 {
|
||||
return nil, fmt.Errorf("coordinate array must have at least 2 elements, got %d", len(v))
|
||||
}
|
||||
|
||||
coords := make([]float64, len(v))
|
||||
for i, val := range v {
|
||||
switch num := val.(type) {
|
||||
case float64:
|
||||
// Convert relative coordinates to absolute coordinates using DefaultFactor
|
||||
if i%2 == 0 { // x coordinates
|
||||
coords[i] = convertRelativeToAbsolute(num, true, size)
|
||||
} else { // y coordinates
|
||||
coords[i] = convertRelativeToAbsolute(num, false, size)
|
||||
}
|
||||
case int:
|
||||
numFloat := float64(num)
|
||||
// Convert relative coordinates to absolute coordinates using DefaultFactor
|
||||
if i%2 == 0 { // x coordinates
|
||||
coords[i] = convertRelativeToAbsolute(numFloat, true, size)
|
||||
} else { // y coordinates
|
||||
coords[i] = convertRelativeToAbsolute(numFloat, false, size)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("coordinate value must be a number, got %T", val)
|
||||
}
|
||||
}
|
||||
return coords, nil
|
||||
|
||||
case []float64:
|
||||
// Handle already parsed float64 slice
|
||||
coords := make([]float64, len(v))
|
||||
for i, val := range v {
|
||||
if i%2 == 0 { // x coordinates
|
||||
coords[i] = convertRelativeToAbsolute(val, true, size)
|
||||
} else { // y coordinates
|
||||
coords[i] = convertRelativeToAbsolute(val, false, size)
|
||||
}
|
||||
}
|
||||
return coords, nil
|
||||
|
||||
case string:
|
||||
// Handle string format (from UI-TARS or string coordinates)
|
||||
return normalizeStringCoordinates(v, size)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported coordinate format: %T", coordData)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeStringParam cleans up a non-coordinate argument value.
//
// Values that are not strings (including nil) are returned untouched. String
// values are trimmed of surrounding whitespace. For the "content" parameter
// the escape sequences \n, \" and \' are additionally replaced by the
// characters they stand for.
func normalizeStringParam(paramName string, paramValue interface{}) interface{} {
	text, isString := paramValue.(string)
	if !isString {
		// Covers nil as well: a nil interface never asserts to string.
		return paramValue
	}

	trimmed := strings.TrimSpace(text)
	if trimmed == "" || paramName != "content" {
		return trimmed
	}

	// Every pattern starts with a backslash and none of the replacements
	// contain one, so a single pass yields the same result as applying the
	// three substitutions one after another.
	unescape := strings.NewReplacer("\\n", "\n", "\\\"", "\"", "\\'", "'")
	return unescape.Replace(trimmed)
}
|
||||
|
||||
// normalizeStringCoordinates normalizes coordinates from string format
|
||||
func normalizeStringCoordinates(coordStr string, size types.Size) ([]float64, error) {
|
||||
// check empty string
|
||||
if coordStr == "" {
|
||||
return nil, fmt.Errorf("empty coordinate string")
|
||||
}
|
||||
|
||||
// Apply coordinate format normalization using the shared function
|
||||
normalizedStr := normalizeCoordinatesFormat(coordStr)
|
||||
|
||||
// Extract numbers from the normalized string using regex
|
||||
re := regexp.MustCompile(`\d+`)
|
||||
numbers := re.FindAllString(normalizedStr, -1)
|
||||
if len(numbers) >= 2 {
|
||||
coords := make([]float64, len(numbers))
|
||||
for i, numStr := range numbers {
|
||||
num, err := strconv.ParseFloat(numStr, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid coordinate: %s", numStr)
|
||||
}
|
||||
// Convert relative coordinates to absolute coordinates
|
||||
if i%2 == 0 { // x coordinates
|
||||
coords[i] = convertRelativeToAbsolute(num, true, size)
|
||||
} else { // y coordinates
|
||||
coords[i] = convertRelativeToAbsolute(num, false, size)
|
||||
}
|
||||
}
|
||||
return coords, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid coordinate string format: %s", coordStr)
|
||||
}
|
||||
|
||||
// Action is one parsed action together with its arguments.
type Action struct {
	// ActionType is the name of the action; it is serialized as "action_type".
	ActionType string `json:"action_type"`
	// ActionInputs holds the action's arguments keyed by parameter name; it is
	// serialized as "action_inputs".
	ActionInputs map[string]interface{} `json:"action_inputs"`
}
|
||||
|
||||
// convertActionsToToolCalls converts actions to tool calls
|
||||
func (p *UITARSContentParser) convertActionsToToolCalls(actions []Action) []schema.ToolCall {
|
||||
// This is a shared function used by both JSONContentParser and UITARSContentParser
|
||||
func convertActionsToToolCalls(actions []Action) []schema.ToolCall {
|
||||
toolCalls := make([]schema.ToolCall, 0, len(actions))
|
||||
for _, action := range actions {
|
||||
jsonArgs, err := json.Marshal(action.ActionInputs)
|
||||
@@ -245,45 +377,3 @@ func (p *UITARSContentParser) convertActionsToToolCalls(actions []Action) []sche
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// Action represents a parsed action with its context.
|
||||
type Action struct {
|
||||
ActionType string `json:"action_type"`
|
||||
ActionInputs map[string]any `json:"action_inputs"`
|
||||
}
|
||||
|
||||
// ParseAction parses an action string of the form
//
//	name(key1='value1', key2="value2")
//
// into its function name and a map of arguments.
//
// Parameters are separated by commas, but commas inside a single- or
// double-quoted value are part of that value, so
// click(start_box='100,200,150,250') yields the single argument
// start_box="100,200,150,250". Surrounding quotes are stripped from values and
// parameters without an "=" are ignored.
//
// An error is returned when actionStr contains no "(" and therefore is not a
// function call.
func ParseAction(actionStr string) (*ParsedAction, error) {
	funcName, paramsText, found := strings.Cut(actionStr, "(")
	if !found {
		return nil, fmt.Errorf("not a function call")
	}

	funcName = strings.TrimSpace(funcName)
	paramsText = strings.TrimSuffix(strings.TrimSpace(paramsText), ")")

	args := make(map[string]string)
	for _, param := range splitActionParams(paramsText) {
		key, value, ok := strings.Cut(strings.TrimSpace(param), "=")
		if !ok {
			continue
		}
		// Remove surrounding quotes.
		args[strings.TrimSpace(key)] = strings.Trim(strings.TrimSpace(value), "'\"")
	}

	return &ParsedAction{Function: funcName, Args: args}, nil
}

// splitActionParams splits a parameter list on commas that are outside quoted
// values. A backslash escapes the following character, so an escaped quote
// does not terminate the quoted value it appears in.
func splitActionParams(paramsText string) []string {
	var (
		params  []string
		quote   rune // active quote character; 0 while outside quotes
		escaped bool // true when the previous character was a backslash
		start   int  // byte offset where the current parameter begins
	)
	for i, r := range paramsText {
		switch {
		case escaped:
			escaped = false
		case r == '\\':
			escaped = true
		case quote != 0:
			if r == quote {
				quote = 0
			}
		case r == '\'' || r == '"':
			quote = r
		case r == ',':
			params = append(params, paramsText[start:i])
			start = i + 1
		}
	}
	return append(params, paramsText[start:])
}

// ParsedAction represents the result of parsing an action string.
type ParsedAction struct {
	// Function is the action name that precedes the opening parenthesis.
	Function string
	// Args maps each parameter name to its unquoted value.
	Args map[string]string
}
|
||||
|
||||
@@ -28,10 +28,9 @@ type PlanningOptions struct {
|
||||
// PlanningResult represents the result of planning
|
||||
type PlanningResult struct {
|
||||
ToolCalls []schema.ToolCall `json:"tool_calls"`
|
||||
Actions []Action `json:"actions"` // TODO: merge to ToolCalls
|
||||
ActionSummary string `json:"summary"`
|
||||
Thought string `json:"thought"`
|
||||
Text string `json:"text"`
|
||||
Content string `json:"content"` // original content from model
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -60,4 +60,57 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
||||
`
|
||||
|
||||
// system prompt for JSONContentParser
|
||||
const defaultPlanningResponseJsonFormat = `You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.`
|
||||
const defaultPlanningResponseJsonFormat = `You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
Target: User will give you a screenshot, an instruction and some previous logs indicating what have been done. Please tell what the next one action is (or null if no action should be done) to do the tasks the instruction requires.
|
||||
|
||||
Restriction:
|
||||
- Don't give extra actions or plans beyond the instruction. ONLY plan for what the instruction requires. For example, don't try to submit the form if the instruction is only to fill something.
|
||||
- Always give ONLY ONE action in ` + "`log`" + ` field (or null if no action should be done), instead of multiple actions. Supported actions are click, long_press, type, scroll, drag, press_home, press_back, wait, finished.
|
||||
- Don't repeat actions in the previous logs.
|
||||
- Bbox is the bounding box of the element to be located. It's an array of 4 numbers, representing [x1, y1, x2, y2] coordinates in 1000x1000 relative coordinates system.
|
||||
|
||||
Supporting actions:
|
||||
- click: { action_type: "click", action_inputs: { startBox: [x1, y1, x2, y2] } }
|
||||
- long_press: { action_type: "long_press", action_inputs: { startBox: [x1, y1, x2, y2] } }
|
||||
- type: { action_type: "type", action_inputs: { content: string } } // If you want to submit your input, use "\\n" at the end of content.
|
||||
- scroll: { action_type: "scroll", action_inputs: { startBox: [x1, y1, x2, y2], direction: "down" | "up" | "left" | "right" } }
|
||||
- drag: { action_type: "drag", action_inputs: { startBox: [x1, y1, x2, y2], endBox: [x3, y3, x4, y4] } }
|
||||
- press_home: { action_type: "press_home", action_inputs: {} }
|
||||
- press_back: { action_type: "press_back", action_inputs: {} }
|
||||
- wait: { action_type: "wait", action_inputs: {} } // Sleep for 5s and take a screenshot to check for any changes.
|
||||
- finished: { action_type: "finished", action_inputs: { content: string } } // Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
|
||||
Field description:
|
||||
* The ` + "`startBox`" + ` and ` + "`endBox`" + ` fields represent the bounding box coordinates of the target element in 1000x1000 relative coordinate system.
|
||||
* Use Chinese in log and summary fields.
|
||||
|
||||
Return in JSON format:
|
||||
{
|
||||
"actions": [
|
||||
{
|
||||
"action_type": "...",
|
||||
"action_inputs": { ... }
|
||||
}
|
||||
],
|
||||
"summary": "string", // Log what the next action you can do according to the screenshot and the instruction. Use Chinese.
|
||||
"error": "string" | null, // Error messages about unexpected situations, if any. Use Chinese.
|
||||
}
|
||||
|
||||
For example, when the instruction is "点击第二个帖子的作者头像", by viewing the screenshot, you should consider locating the second post's author avatar and output the JSON:
|
||||
|
||||
{
|
||||
"actions": [
|
||||
{
|
||||
"action_type": "click",
|
||||
"action_inputs": {
|
||||
"startBox": [100, 200, 150, 250]
|
||||
}
|
||||
}
|
||||
],
|
||||
"summary": "点击第二个帖子的作者头像",
|
||||
"error": null
|
||||
}
|
||||
|
||||
## User Instruction
|
||||
`
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestVLMPlanning(t *testing.T) {
|
||||
|
||||
userInstruction += "\n\n请基于以上游戏规则,给出下一步可点击的两个图标坐标"
|
||||
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeDoubaoVL)
|
||||
require.NoError(t, err)
|
||||
|
||||
planner, err := NewPlanner(context.Background(), modelConfig)
|
||||
@@ -63,7 +63,7 @@ func TestVLMPlanning(t *testing.T) {
|
||||
toolCall := result.ToolCalls[0]
|
||||
assert.NotEmpty(t, toolCall.Function.Name)
|
||||
assert.NotEmpty(t, result.Thought)
|
||||
assert.NotEmpty(t, result.Text)
|
||||
assert.NotEmpty(t, result.Content)
|
||||
}
|
||||
|
||||
func TestXHSPlanning(t *testing.T) {
|
||||
@@ -100,13 +100,13 @@ func TestXHSPlanning(t *testing.T) {
|
||||
// 验证结果
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotEmpty(t, result.Actions)
|
||||
require.NotEmpty(t, result.ToolCalls)
|
||||
|
||||
// 验证动作
|
||||
action := result.Actions[0]
|
||||
assert.NotEmpty(t, action.ActionType)
|
||||
toolCall := result.ToolCalls[0]
|
||||
assert.NotEmpty(t, toolCall.Function.Name)
|
||||
assert.NotEmpty(t, result.Thought)
|
||||
assert.NotEmpty(t, result.Text)
|
||||
assert.NotEmpty(t, result.Content)
|
||||
}
|
||||
|
||||
func TestChatList(t *testing.T) {
|
||||
@@ -146,9 +146,7 @@ func TestChatList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleSwitch(t *testing.T) {
|
||||
userInstruction := "发送框下方的联网搜索开关是开启状态" // 点击开启联网搜索开关
|
||||
// 检查发送框下方的联网搜索开关,蓝色为开启状态,灰色为关闭状态;若开关处于关闭状态,则点击进行开启
|
||||
|
||||
userInstruction := "检查发送框下方的联网搜索开关,蓝色为开启状态,灰色为关闭状态;若开关处于关闭状态,则点击进行开启"
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -159,9 +157,9 @@ func TestHandleSwitch(t *testing.T) {
|
||||
imageFile string
|
||||
actionType string
|
||||
}{
|
||||
{"testdata/deepseek_think_off.png", "finished"},
|
||||
{"testdata/deepseek_think_on.png", "finished"},
|
||||
{"testdata/deepseek_network_on.png", "finished"},
|
||||
{"testdata/deepseek_think_off.png", "click"}, // 关闭状态,需要点击开启
|
||||
{"testdata/deepseek_think_on.png", "click"}, // 关闭状态,需要点击开启
|
||||
{"testdata/deepseek_network_on.png", "finished"}, // 开启状态,无需操作
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -190,7 +188,7 @@ func TestHandleSwitch(t *testing.T) {
|
||||
// Validate results
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, result.Actions[0].ActionType, tc.actionType,
|
||||
require.Equal(t, result.ToolCalls[0].Function.Name, tc.actionType,
|
||||
"Unexpected action type for image file: %s", tc.imageFile)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user