refactor: simplify AI action execution and improve sub-action handling

This commit is contained in:
lilong.129
2025-06-08 19:16:37 +08:00
parent bdf64a08aa
commit b9de3cf7a3
8 changed files with 124 additions and 184 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2506081005
v5.0.0-beta-2506081916

View File

@@ -784,12 +784,13 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
},
StartTime: startTime.Unix(),
}
if app, err1 := uiDriver.ForegroundInfo(); err1 == nil {
attachments["foreground_app"] = app.AppBaseInfo
} else {
log.Warn().Err(err1).Msg("save foreground app failed, ignore")
subActionResults, err1 := uiDriver.ExecuteAction(
context.Background(), actionResult.MobileAction)
if err1 != nil {
log.Warn().Err(err1).Msg("get foreground app failed, ignore")
}
actionResult.Elapsed = time.Since(startTime).Milliseconds()
actionResult.SubActions = subActionResults
stepResult.Actions = append(stepResult.Actions, actionResult)
}
@@ -807,17 +808,16 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
},
StartTime: startTime.Unix(),
}
if err2 := uiDriver.ClosePopupsHandler(); err2 != nil {
log.Error().Err(err2).Str("step", step.Name()).Msg("auto handle popup failed")
subActionResults, err2 := uiDriver.ExecuteAction(
context.Background(), actionResult.MobileAction)
if err2 != nil {
log.Warn().Err(err2).Str("step", step.Name()).Msg("auto handle popup failed")
}
actionResult.Elapsed = time.Since(startTime).Milliseconds()
actionResult.SubActions = subActionResults
stepResult.Actions = append(stepResult.Actions, actionResult)
}
// save attachments
for key, value := range uiDriver.GetData(true) {
attachments[key] = value
}
stepResult.Attachments = attachments
stepResult.Elapsed = time.Since(start).Milliseconds()
}()
@@ -907,7 +907,23 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
}
}()
// action execution
// handle start_to_goal action
if action.Method == option.ACTION_StartToGoal {
subActionResults, err := uiDriver.StartToGoal(ctx,
action.Params.(string), action.GetOptions()...)
actionResult.Elapsed = time.Since(actionStartTime).Milliseconds()
actionResult.SubActions = subActionResults
stepResult.Actions = append(stepResult.Actions, actionResult)
if err != nil {
if !code.IsErrorPredefined(err) {
err = errors.Wrap(code.MobileUIDriverError, err.Error())
}
return stepResult, err
}
continue
}
// handle other actions
subActionResults, err := uiDriver.ExecuteAction(ctx, action)
actionResult.Elapsed = time.Since(actionStartTime).Milliseconds()
actionResult.SubActions = subActionResults

View File

@@ -3,7 +3,6 @@ package uixt
import (
"context"
"encoding/base64"
"strings"
"time"
"github.com/cloudwego/eino/schema"
@@ -49,6 +48,11 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op
Msg("LLM service request failed, retrying...")
continue
}
allSubActions = append(allSubActions, &SubActionResult{
ActionName: "plan_next_action",
Arguments: prompt,
Error: err,
})
return allSubActions, err
}
@@ -59,10 +63,33 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op
}
// Invoke tool calls
subActions, err := dExt.invokeToolCalls(ctx, result.Thought, result.ToolCalls)
allSubActions = append(allSubActions, subActions...)
if err != nil {
return allSubActions, err
for _, toolCall := range result.ToolCalls {
// Check for context cancellation before each action
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in invokeToolCalls")
return allSubActions, errors.Wrap(code.InterruptError, "invokeToolCalls interrupted")
default:
}
subActionStartTime := time.Now()
// Create sub-action result
subActionResult := &SubActionResult{
ActionName: toolCall.Function.Name,
Arguments: toolCall.Function.Arguments,
StartTime: subActionStartTime.Unix(),
Thought: result.Thought,
}
if err := dExt.invokeToolCall(ctx, toolCall); err != nil {
subActionResult.Error = err
allSubActions = append(allSubActions, subActionResult)
return allSubActions, err
}
// Collect sub-action specific attachments and reset session data
subActionResult.SessionData = dExt.GetSession().GetData(true) // reset after getting data
allSubActions = append(allSubActions, subActionResult)
}
if options.MaxRetryTimes > 1 && attempt >= options.MaxRetryTimes {
@@ -71,22 +98,24 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op
}
}
func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...option.ActionOption) ([]*SubActionResult, error) {
func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...option.ActionOption) error {
log.Info().Str("prompt", prompt).Msg("performing AI action")
// plan next action
result, err := dExt.PlanNextAction(ctx, prompt, opts...)
if err != nil {
return nil, err
return err
}
// Invoke tool calls
subActionResults, err := dExt.invokeToolCalls(ctx, result.Thought, result.ToolCalls)
if err != nil {
return subActionResults, err
for _, toolCall := range result.ToolCalls {
err = dExt.invokeToolCall(ctx, toolCall)
if err != nil {
return err
}
}
return subActionResults, nil
return nil
}
func (dExt *XTDriver) PlanNextAction(ctx context.Context, prompt string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
@@ -159,88 +188,49 @@ func (dExt *XTDriver) isTaskFinished(result *ai.PlanningResult) bool {
return false
}
// invokeToolCalls invokes the tool calls and returns sub-action results
func (dExt *XTDriver) invokeToolCalls(ctx context.Context, thought string, toolCalls []schema.ToolCall) ([]*SubActionResult, error) {
var subActionResults []*SubActionResult
for _, action := range toolCalls {
// Check for context cancellation before each action
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in invokeToolCalls")
return subActionResults, errors.Wrap(code.InterruptError, "invokeToolCalls interrupted")
default:
}
subActionStartTime := time.Now()
// Extract action name (remove "uixt__" prefix)
actionName := strings.TrimPrefix(action.Function.Name, "uixt__")
// Parse arguments
arguments := make(map[string]interface{})
err := json.Unmarshal([]byte(action.Function.Arguments), &arguments)
if err != nil {
return subActionResults, err
}
// Create sub-action result
subActionResult := &SubActionResult{
ActionName: actionName,
Arguments: arguments,
StartTime: subActionStartTime.Unix(),
Thought: thought,
}
// Execute the action
req := mcp.CallToolRequest{
Params: struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
Meta *struct {
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
} `json:"_meta,omitempty"`
}{
Name: action.Function.Name,
Arguments: arguments,
},
}
_, err = dExt.client.CallTool(ctx, req)
subActionResult.Elapsed = time.Since(subActionStartTime).Milliseconds()
if err != nil {
subActionResult.Error = err
subActionResults = append(subActionResults, subActionResult)
return subActionResults, err
}
// Collect sub-action specific attachments and reset session data
subActionData := dExt.GetData(true) // reset after getting data
// Add requests if any
if requests, ok := subActionData["requests"].([]*DriverRequests); ok && len(requests) > 0 {
subActionResult.Requests = requests
}
// Add screen_results if any
if screenResults, ok := subActionData["screen_results"].([]*ScreenResult); ok && len(screenResults) > 0 {
subActionResult.ScreenResults = screenResults
}
subActionResults = append(subActionResults, subActionResult)
// invokeToolCall invokes the tool call
func (dExt *XTDriver) invokeToolCall(ctx context.Context, toolCall schema.ToolCall) error {
// Parse arguments
arguments := make(map[string]interface{})
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments)
if err != nil {
return err
}
return subActionResults, nil
// Execute the action
req := mcp.CallToolRequest{
Params: struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
Meta *struct {
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
} `json:"_meta,omitempty"`
}{
Name: toolCall.Function.Name,
Arguments: arguments,
},
}
_, err = dExt.client.CallTool(ctx, req)
if err != nil {
return err
}
return nil
}
// SubActionResult represents a sub-action within a start_to_goal action
type SubActionResult struct {
ActionName string `json:"action_name"` // name of the sub-action (e.g., "tap", "input")
Arguments interface{} `json:"arguments,omitempty"` // arguments passed to the sub-action
StartTime int64 `json:"start_time"` // sub-action start time
Elapsed int64 `json:"elapsed_ms"` // sub-action elapsed time(ms)
Error error `json:"error,omitempty"` // sub-action execution result
Thought string `json:"thought,omitempty"` // sub-action thought
ActionName string `json:"action_name"` // name of the sub-action (e.g., "tap", "input")
Arguments interface{} `json:"arguments,omitempty"` // arguments passed to the sub-action
StartTime int64 `json:"start_time"` // sub-action start time
Elapsed int64 `json:"elapsed_ms"` // sub-action elapsed time(ms)
Error error `json:"error,omitempty"` // sub-action execution result
Thought string `json:"thought,omitempty"` // sub-action thought
SessionData
}
type SessionData struct {
Requests []*DriverRequests `json:"requests,omitempty"` // store sub-action specific requests
ScreenResults []*ScreenResult `json:"screen_results,omitempty"` // store sub-action specific screen_results
}

View File

@@ -15,9 +15,8 @@ import (
func TestDriverExt_TapByLLM(t *testing.T) {
driver := setupDriverExt(t)
subActionResults, err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像")
err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像")
assert.Nil(t, err)
t.Log(subActionResults)
err = driver.AIAssert("当前在个人介绍页")
assert.Nil(t, err)

View File

@@ -76,6 +76,17 @@ func (s *DriverSession) Reset() {
s.screenResults = make([]*ScreenResult, 0)
}
func (s *DriverSession) GetData(withReset bool) SessionData {
sessionData := SessionData{
Requests: s.History(),
ScreenResults: s.screenResults,
}
if withReset {
s.Reset()
}
return sessionData
}
func (s *DriverSession) SetBaseURL(baseUrl string) {
s.baseUrl = baseUrl
}

View File

@@ -112,18 +112,6 @@ func (dExt *XTDriver) Setup() error {
return nil
}
func (dExt *XTDriver) GetData(withReset bool) map[string]interface{} {
session := dExt.GetSession()
data := map[string]interface{}{
"requests": session.History(),
"screen_results": session.screenResults,
}
if withReset {
session.Reset()
}
return data
}
func (dExt *XTDriver) assertOCR(text, assert string) error {
var opts []option.ActionOption
opts = append(opts, option.WithScreenShotFileName(fmt.Sprintf("assert_ocr_%s", text)))

View File

@@ -13,8 +13,7 @@ import (
// ToolStartToGoal implements the start_to_goal tool call.
type ToolStartToGoal struct {
// Return data fields - these define the structure of data returned by this tool
Prompt string `json:"prompt" desc:"Goal prompt that was executed"`
SubActions []*SubActionResult `json:"sub_actions" desc:"Sub-actions that were executed"`
Prompt string `json:"prompt" desc:"Goal prompt that was executed"`
}
func (t *ToolStartToGoal) Name() option.ActionName {
@@ -43,15 +42,14 @@ func (t *ToolStartToGoal) Implement() server.ToolHandlerFunc {
}
// Start to goal logic
subActionResults, err := driverExt.StartToGoal(ctx, unifiedReq.Prompt)
_, err = driverExt.StartToGoal(ctx, unifiedReq.Prompt)
if err != nil {
return NewMCPErrorResponse(fmt.Sprintf("Failed to achieve goal: %s", err.Error())), nil
}
message := fmt.Sprintf("Successfully achieved goal: %s", unifiedReq.Prompt)
returnData := ToolStartToGoal{
Prompt: unifiedReq.Prompt,
SubActions: subActionResults,
Prompt: unifiedReq.Prompt,
}
return NewMCPSuccessResponse(message, &returnData), nil
@@ -75,8 +73,7 @@ func (t *ToolStartToGoal) ConvertActionToCallToolRequest(action option.MobileAct
// ToolAIAction implements the ai_action tool call.
type ToolAIAction struct {
// Return data fields - these define the structure of data returned by this tool
Prompt string `json:"prompt" desc:"AI action prompt that was executed"`
SubActions []*SubActionResult `json:"sub_actions" desc:"Sub-actions that were executed"`
Prompt string `json:"prompt" desc:"AI action prompt that was executed"`
}
func (t *ToolAIAction) Name() option.ActionName {
@@ -105,15 +102,14 @@ func (t *ToolAIAction) Implement() server.ToolHandlerFunc {
}
// AI action logic
subActionResults, err := driverExt.AIAction(ctx, unifiedReq.Prompt)
err = driverExt.AIAction(ctx, unifiedReq.Prompt)
if err != nil {
return NewMCPErrorResponse(fmt.Sprintf("AI action failed: %s", err.Error())), nil
}
message := fmt.Sprintf("Successfully performed AI action with prompt: %s", unifiedReq.Prompt)
returnData := ToolAIAction{
Prompt: unifiedReq.Prompt,
SubActions: subActionResults,
Prompt: unifiedReq.Prompt,
}
return NewMCPSuccessResponse(message, &returnData), nil

View File

@@ -6,7 +6,6 @@ import (
"strings"
"time"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/mark3labs/mcp-go/client"
@@ -133,73 +132,14 @@ func (dExt *XTDriver) ExecuteAction(ctx context.Context, action option.MobileAct
return []*SubActionResult{subActionResult}, err
}
// Handle special AI actions (start_to_goal, ai_action) that return sub-actions
if action.Method == option.ACTION_StartToGoal || action.Method == option.ACTION_AIAction {
return dExt.parseAIActionResult(result, subActionResult)
}
// For regular actions, collect session data and return single sub-action result
subActionData := dExt.GetData(true) // reset after getting data
// Add requests if any
if requests, ok := subActionData["requests"].([]*DriverRequests); ok && len(requests) > 0 {
subActionResult.Requests = requests
}
// Add screen_results if any
if screenResults, ok := subActionData["screen_results"].([]*ScreenResult); ok && len(screenResults) > 0 {
subActionResult.ScreenResults = screenResults
}
subActionResult.SessionData = dExt.GetSession().GetData(true) // reset after getting data
log.Debug().Str("tool", string(tool.Name())).
Msg("execute action via MCP tool")
return []*SubActionResult{subActionResult}, nil
}
// parseAIActionResult parses the result from AI actions (start_to_goal, ai_action) and extracts sub-actions
func (dExt *XTDriver) parseAIActionResult(result *mcp.CallToolResult, originalSubAction *SubActionResult) ([]*SubActionResult, error) {
// Parse the JSON response to extract sub_actions
var responseData map[string]interface{}
if len(result.Content) > 0 {
// Get the first text content
if textContent, ok := result.Content[0].(mcp.TextContent); ok {
if err := json.Unmarshal([]byte(textContent.Text), &responseData); err != nil {
log.Warn().Err(err).Msg("failed to parse AI action result, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
} else {
log.Warn().Msg("AI action result is not text content, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
}
// Extract sub_actions from the response
if subActionsData, ok := responseData["sub_actions"]; ok {
// Convert to JSON and back to properly deserialize SubActionResult structs
subActionsJSON, err := json.Marshal(subActionsData)
if err != nil {
log.Warn().Err(err).Msg("failed to marshal sub_actions, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
var subActionResults []*SubActionResult
if err := json.Unmarshal(subActionsJSON, &subActionResults); err != nil {
log.Warn().Err(err).Msg("failed to unmarshal sub_actions, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
log.Debug().Int("sub_actions_count", len(subActionResults)).
Str("action", string(originalSubAction.ActionName)).
Msg("parsed AI action sub-actions")
return subActionResults, nil
}
// If no sub_actions found, return the original action as a single result
log.Debug().Str("action", string(originalSubAction.ActionName)).
Msg("no sub_actions found in AI action result, using single action")
return []*SubActionResult{originalSubAction}, nil
}
// NewDeviceWithDefault is a helper function to create a device with default options
func NewDeviceWithDefault(platform, serial string) (device IDevice, err error) {
if serial == "" {