From b9de3cf7a3f1bb0a81604c4968943269b81a5510 Mon Sep 17 00:00:00 2001 From: "lilong.129" Date: Sun, 8 Jun 2025 19:16:37 +0800 Subject: [PATCH] refactor: simplify AI action execution and improve sub-action handling --- internal/version/VERSION | 2 +- step_ui.go | 38 ++++++--- uixt/driver_ext_ai.go | 164 +++++++++++++++++-------------------- uixt/driver_ext_ai_test.go | 3 +- uixt/driver_session.go | 11 +++ uixt/driver_utils.go | 12 --- uixt/mcp_tools_ai.go | 16 ++-- uixt/sdk.go | 62 +------------- 8 files changed, 124 insertions(+), 184 deletions(-) diff --git a/internal/version/VERSION b/internal/version/VERSION index f4b5ba09..802023e6 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2506081005 +v5.0.0-beta-2506081916 diff --git a/step_ui.go b/step_ui.go index f6ce4914..ca2c430b 100644 --- a/step_ui.go +++ b/step_ui.go @@ -784,12 +784,13 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err }, StartTime: startTime.Unix(), } - if app, err1 := uiDriver.ForegroundInfo(); err1 == nil { - attachments["foreground_app"] = app.AppBaseInfo - } else { - log.Warn().Err(err1).Msg("save foreground app failed, ignore") + subActionResults, err1 := uiDriver.ExecuteAction( + context.Background(), actionResult.MobileAction) + if err1 != nil { + log.Warn().Err(err1).Msg("get foreground app failed, ignore") } actionResult.Elapsed = time.Since(startTime).Milliseconds() + actionResult.SubActions = subActionResults stepResult.Actions = append(stepResult.Actions, actionResult) } @@ -807,17 +808,16 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err }, StartTime: startTime.Unix(), } - if err2 := uiDriver.ClosePopupsHandler(); err2 != nil { - log.Error().Err(err2).Str("step", step.Name()).Msg("auto handle popup failed") + subActionResults, err2 := uiDriver.ExecuteAction( + context.Background(), actionResult.MobileAction) + if err2 != nil { + log.Warn().Err(err2).Str("step", step.Name()).Msg("auto handle popup failed") } actionResult.Elapsed = time.Since(startTime).Milliseconds() + actionResult.SubActions = subActionResults stepResult.Actions = append(stepResult.Actions, actionResult) } - // save attachments - for key, value := range uiDriver.GetData(true) { - attachments[key] = value - } stepResult.Attachments = attachments stepResult.Elapsed = time.Since(start).Milliseconds() }() @@ -907,7 +907,23 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err } }() - // action execution + // handle start_to_goal action + if action.Method == option.ACTION_StartToGoal { + subActionResults, err := uiDriver.StartToGoal(ctx, + action.Params.(string), action.GetOptions()...) + actionResult.Elapsed = time.Since(actionStartTime).Milliseconds() + actionResult.SubActions = subActionResults + stepResult.Actions = append(stepResult.Actions, actionResult) + if err != nil { + if !code.IsErrorPredefined(err) { + err = errors.Wrap(code.MobileUIDriverError, err.Error()) + } + return stepResult, err + } + continue + } + + // handle other actions subActionResults, err := uiDriver.ExecuteAction(ctx, action) actionResult.Elapsed = time.Since(actionStartTime).Milliseconds() actionResult.SubActions = subActionResults diff --git a/uixt/driver_ext_ai.go b/uixt/driver_ext_ai.go index 239c604a..91e5f6f6 100644 --- a/uixt/driver_ext_ai.go +++ b/uixt/driver_ext_ai.go @@ -3,7 +3,6 @@ package uixt import ( "context" "encoding/base64" - "strings" "time" "github.com/cloudwego/eino/schema" @@ -49,6 +48,11 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op Msg("LLM service request failed, retrying...") continue } + allSubActions = append(allSubActions, &SubActionResult{ + ActionName: "plan_next_action", + Arguments: prompt, + Error: err, + }) return allSubActions, err } @@ -59,10 +63,33 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op } // Invoke tool calls - subActions, err := dExt.invokeToolCalls(ctx, result.Thought, result.ToolCalls) - allSubActions = append(allSubActions, subActions...) - if err != nil { - return allSubActions, err + for _, toolCall := range result.ToolCalls { + // Check for context cancellation before each action + select { + case <-ctx.Done(): + log.Warn().Msg("interrupted in invokeToolCalls") + return allSubActions, errors.Wrap(code.InterruptError, "invokeToolCalls interrupted") + default: + } + + subActionStartTime := time.Now() + // Create sub-action result + subActionResult := &SubActionResult{ + ActionName: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + StartTime: subActionStartTime.Unix(), + Thought: result.Thought, + } + + if err := dExt.invokeToolCall(ctx, toolCall); err != nil { + subActionResult.Error = err + allSubActions = append(allSubActions, subActionResult) + return allSubActions, err + } + + // Collect sub-action specific attachments and reset session data + subActionResult.SessionData = dExt.GetSession().GetData(true) // reset after getting data + allSubActions = append(allSubActions, subActionResult) } if options.MaxRetryTimes > 1 && attempt >= options.MaxRetryTimes { @@ -71,22 +98,24 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op } } -func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...option.ActionOption) ([]*SubActionResult, error) { +func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...option.ActionOption) error { log.Info().Str("prompt", prompt).Msg("performing AI action") // plan next action result, err := dExt.PlanNextAction(ctx, prompt, opts...) if err != nil { - return nil, err + return err } // Invoke tool calls - subActionResults, err := dExt.invokeToolCalls(ctx, result.Thought, result.ToolCalls) - if err != nil { - return subActionResults, err + for _, toolCall := range result.ToolCalls { + err = dExt.invokeToolCall(ctx, toolCall) + if err != nil { + return err + } } - return subActionResults, nil + return nil } func (dExt *XTDriver) PlanNextAction(ctx context.Context, prompt string, opts ...option.ActionOption) (*ai.PlanningResult, error) { @@ -159,88 +188,49 @@ func (dExt *XTDriver) isTaskFinished(result *ai.PlanningResult) bool { return false } -// invokeToolCalls invokes the tool calls and returns sub-action results -func (dExt *XTDriver) invokeToolCalls(ctx context.Context, thought string, toolCalls []schema.ToolCall) ([]*SubActionResult, error) { - var subActionResults []*SubActionResult - - for _, action := range toolCalls { - // Check for context cancellation before each action - select { - case <-ctx.Done(): - log.Warn().Msg("interrupted in invokeToolCalls") - return subActionResults, errors.Wrap(code.InterruptError, "invokeToolCalls interrupted") - default: - } - - subActionStartTime := time.Now() - - // Extract action name (remove "uixt__" prefix) - actionName := strings.TrimPrefix(action.Function.Name, "uixt__") - - // Parse arguments - arguments := make(map[string]interface{}) - err := json.Unmarshal([]byte(action.Function.Arguments), &arguments) - if err != nil { - return subActionResults, err - } - - // Create sub-action result - subActionResult := &SubActionResult{ - ActionName: actionName, - Arguments: arguments, - StartTime: subActionStartTime.Unix(), - Thought: thought, - } - - // Execute the action - req := mcp.CallToolRequest{ - Params: struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments,omitempty"` - Meta *struct { - ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` - } `json:"_meta,omitempty"` - }{ - Name: action.Function.Name, - Arguments: arguments, - }, - } - - _, err = dExt.client.CallTool(ctx, req) - subActionResult.Elapsed = time.Since(subActionStartTime).Milliseconds() - if err != nil { - subActionResult.Error = err - subActionResults = append(subActionResults, subActionResult) - return subActionResults, err - } - - // Collect sub-action specific attachments and reset session data - subActionData := dExt.GetData(true) // reset after getting data - - // Add requests if any - if requests, ok := subActionData["requests"].([]*DriverRequests); ok && len(requests) > 0 { - subActionResult.Requests = requests - } - - // Add screen_results if any - if screenResults, ok := subActionData["screen_results"].([]*ScreenResult); ok && len(screenResults) > 0 { - subActionResult.ScreenResults = screenResults - } - - subActionResults = append(subActionResults, subActionResult) +// invokeToolCall invokes the tool call +func (dExt *XTDriver) invokeToolCall(ctx context.Context, toolCall schema.ToolCall) error { + // Parse arguments + arguments := make(map[string]interface{}) + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments) + if err != nil { + return err } - return subActionResults, nil + // Execute the action + req := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Name: toolCall.Function.Name, + Arguments: arguments, + }, + } + + _, err = dExt.client.CallTool(ctx, req) + if err != nil { + return err + } + + return nil } // SubActionResult represents a sub-action within a start_to_goal action type SubActionResult struct { - ActionName string `json:"action_name"` // name of the sub-action (e.g., "tap", "input") - Arguments interface{} `json:"arguments,omitempty"` // arguments passed to the sub-action - StartTime int64 `json:"start_time"` // sub-action start time - Elapsed int64 `json:"elapsed_ms"` // sub-action elapsed time(ms) - Error error `json:"error,omitempty"` // sub-action execution result - Thought string `json:"thought,omitempty"` // sub-action thought + ActionName string `json:"action_name"` // name of the sub-action (e.g., "tap", "input") + Arguments interface{} `json:"arguments,omitempty"` // arguments passed to the sub-action + StartTime int64 `json:"start_time"` // sub-action start time + Elapsed int64 `json:"elapsed_ms"` // sub-action elapsed time(ms) + Error error `json:"error,omitempty"` // sub-action execution result + Thought string `json:"thought,omitempty"` // sub-action thought + SessionData +} + +type SessionData struct { Requests []*DriverRequests `json:"requests,omitempty"` // store sub-action specific requests ScreenResults []*ScreenResult `json:"screen_results,omitempty"` // store sub-action specific screen_results } diff --git a/uixt/driver_ext_ai_test.go b/uixt/driver_ext_ai_test.go index e05844ad..b8c6d1ea 100644 --- a/uixt/driver_ext_ai_test.go +++ b/uixt/driver_ext_ai_test.go @@ -15,9 +15,8 @@ import ( func TestDriverExt_TapByLLM(t *testing.T) { driver := setupDriverExt(t) - subActionResults, err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像") + err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像") assert.Nil(t, err) - t.Log(subActionResults) err = driver.AIAssert("当前在个人介绍页") assert.Nil(t, err) diff --git a/uixt/driver_session.go b/uixt/driver_session.go index 2bc4107e..d3dba31b 100644 --- a/uixt/driver_session.go +++ b/uixt/driver_session.go @@ -76,6 +76,17 @@ func (s *DriverSession) Reset() { s.screenResults = make([]*ScreenResult, 0) } +func (s *DriverSession) GetData(withReset bool) SessionData { + sessionData := SessionData{ + Requests: s.History(), + ScreenResults: s.screenResults, + } + if withReset { + s.Reset() + } + return sessionData +} + func (s *DriverSession) SetBaseURL(baseUrl string) { s.baseUrl = baseUrl } diff --git a/uixt/driver_utils.go b/uixt/driver_utils.go index d138d030..b5c007bf 100644 --- a/uixt/driver_utils.go +++ b/uixt/driver_utils.go @@ -112,18 +112,6 @@ func (dExt *XTDriver) Setup() error { return nil } -func (dExt *XTDriver) GetData(withReset bool) map[string]interface{} { - session := dExt.GetSession() - data := map[string]interface{}{ - "requests": session.History(), - "screen_results": session.screenResults, - } - if withReset { - session.Reset() - } - return data -} - func (dExt *XTDriver) assertOCR(text, assert string) error { var opts []option.ActionOption opts = append(opts, option.WithScreenShotFileName(fmt.Sprintf("assert_ocr_%s", text))) diff --git a/uixt/mcp_tools_ai.go b/uixt/mcp_tools_ai.go index f5bd7027..0e1f4b5c 100644 --- a/uixt/mcp_tools_ai.go +++ b/uixt/mcp_tools_ai.go @@ -13,8 +13,7 @@ import ( // ToolStartToGoal implements the start_to_goal tool call. type ToolStartToGoal struct { // Return data fields - these define the structure of data returned by this tool - Prompt string `json:"prompt" desc:"Goal prompt that was executed"` - SubActions []*SubActionResult `json:"sub_actions" desc:"Sub-actions that were executed"` + Prompt string `json:"prompt" desc:"Goal prompt that was executed"` } func (t *ToolStartToGoal) Name() option.ActionName { @@ -43,15 +42,14 @@ func (t *ToolStartToGoal) Implement() server.ToolHandlerFunc { } // Start to goal logic - subActionResults, err := driverExt.StartToGoal(ctx, unifiedReq.Prompt) + _, err = driverExt.StartToGoal(ctx, unifiedReq.Prompt) if err != nil { return NewMCPErrorResponse(fmt.Sprintf("Failed to achieve goal: %s", err.Error())), nil } message := fmt.Sprintf("Successfully achieved goal: %s", unifiedReq.Prompt) returnData := ToolStartToGoal{ - Prompt: unifiedReq.Prompt, - SubActions: subActionResults, + Prompt: unifiedReq.Prompt, } return NewMCPSuccessResponse(message, &returnData), nil @@ -75,8 +73,7 @@ func (t *ToolStartToGoal) ConvertActionToCallToolRequest(action option.MobileAct // ToolAIAction implements the ai_action tool call. type ToolAIAction struct { // Return data fields - these define the structure of data returned by this tool - Prompt string `json:"prompt" desc:"AI action prompt that was executed"` - SubActions []*SubActionResult `json:"sub_actions" desc:"Sub-actions that were executed"` + Prompt string `json:"prompt" desc:"AI action prompt that was executed"` } func (t *ToolAIAction) Name() option.ActionName { @@ -105,15 +102,14 @@ func (t *ToolAIAction) Implement() server.ToolHandlerFunc { } // AI action logic - subActionResults, err := driverExt.AIAction(ctx, unifiedReq.Prompt) + err = driverExt.AIAction(ctx, unifiedReq.Prompt) if err != nil { return NewMCPErrorResponse(fmt.Sprintf("AI action failed: %s", err.Error())), nil } message := fmt.Sprintf("Successfully performed AI action with prompt: %s", unifiedReq.Prompt) returnData := ToolAIAction{ - Prompt: unifiedReq.Prompt, - SubActions: subActionResults, + Prompt: unifiedReq.Prompt, } return NewMCPSuccessResponse(message, &returnData), nil diff --git a/uixt/sdk.go b/uixt/sdk.go index 6d1548fc..caf96bf5 100644 --- a/uixt/sdk.go +++ b/uixt/sdk.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/httprunner/httprunner/v5/internal/json" "github.com/httprunner/httprunner/v5/uixt/ai" "github.com/httprunner/httprunner/v5/uixt/option" "github.com/mark3labs/mcp-go/client" @@ -133,73 +132,14 @@ func (dExt *XTDriver) ExecuteAction(ctx context.Context, action option.MobileAct return []*SubActionResult{subActionResult}, err } - // Handle special AI actions (start_to_goal, ai_action) that return sub-actions - if action.Method == option.ACTION_StartToGoal || action.Method == option.ACTION_AIAction { - return dExt.parseAIActionResult(result, subActionResult) - } - // For regular actions, collect session data and return single sub-action result - subActionData := dExt.GetData(true) // reset after getting data - - // Add requests if any - if requests, ok := subActionData["requests"].([]*DriverRequests); ok && len(requests) > 0 { - subActionResult.Requests = requests - } - - // Add screen_results if any - if screenResults, ok := subActionData["screen_results"].([]*ScreenResult); ok && len(screenResults) > 0 { - subActionResult.ScreenResults = screenResults - } + subActionResult.SessionData = dExt.GetSession().GetData(true) // reset after getting data log.Debug().Str("tool", string(tool.Name())). Msg("execute action via MCP tool") return []*SubActionResult{subActionResult}, nil } -// parseAIActionResult parses the result from AI actions (start_to_goal, ai_action) and extracts sub-actions -func (dExt *XTDriver) parseAIActionResult(result *mcp.CallToolResult, originalSubAction *SubActionResult) ([]*SubActionResult, error) { - // Parse the JSON response to extract sub_actions - var responseData map[string]interface{} - if len(result.Content) > 0 { - // Get the first text content - if textContent, ok := result.Content[0].(mcp.TextContent); ok { - if err := json.Unmarshal([]byte(textContent.Text), &responseData); err != nil { - log.Warn().Err(err).Msg("failed to parse AI action result, falling back to single action") - return []*SubActionResult{originalSubAction}, nil - } - } else { - log.Warn().Msg("AI action result is not text content, falling back to single action") - return []*SubActionResult{originalSubAction}, nil - } - } - - // Extract sub_actions from the response - if subActionsData, ok := responseData["sub_actions"]; ok { - // Convert to JSON and back to properly deserialize SubActionResult structs - subActionsJSON, err := json.Marshal(subActionsData) - if err != nil { - log.Warn().Err(err).Msg("failed to marshal sub_actions, falling back to single action") - return []*SubActionResult{originalSubAction}, nil - } - - var subActionResults []*SubActionResult - if err := json.Unmarshal(subActionsJSON, &subActionResults); err != nil { - log.Warn().Err(err).Msg("failed to unmarshal sub_actions, falling back to single action") - return []*SubActionResult{originalSubAction}, nil - } - - log.Debug().Int("sub_actions_count", len(subActionResults)). - Str("action", string(originalSubAction.ActionName)). - Msg("parsed AI action sub-actions") - return subActionResults, nil - } - - // If no sub_actions found, return the original action as a single result - log.Debug().Str("action", string(originalSubAction.ActionName)). - Msg("no sub_actions found in AI action result, using single action") - return []*SubActionResult{originalSubAction}, nil -} - // NewDeviceWithDefault is a helper function to create a device with default options func NewDeviceWithDefault(platform, serial string) (device IDevice, err error) { if serial == "" {