refactor: unify action execution interface and merge AI action handling

This commit is contained in:
lilong.129
2025-06-07 23:59:07 +08:00
parent fcf3009c67
commit ec4f1eb68a
8 changed files with 199 additions and 65 deletions

View File

@@ -3,13 +3,12 @@ package uixt
import (
"context"
"encoding/base64"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/builtin"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
@@ -18,10 +17,11 @@ import (
"github.com/rs/zerolog/log"
)
func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...option.ActionOption) error {
func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...option.ActionOption) ([]*SubActionResult, error) {
options := option.NewActionOptions(opts...)
log.Info().Int("max_retry_times", options.MaxRetryTimes).Msg("StartToGoal")
var allSubActions []*SubActionResult
var attempt int
for {
attempt++
@@ -31,7 +31,7 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in StartToGoal")
return errors.Wrap(code.InterruptError, "StartToGoal interrupted")
return allSubActions, errors.Wrap(code.InterruptError, "StartToGoal interrupted")
default:
}
@@ -49,37 +49,44 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op
Msg("LLM service request failed, retrying...")
continue
}
return err
return allSubActions, err
}
// Check if task is finished BEFORE executing actions
if dExt.isTaskFinished(result) {
log.Info().Msg("task finished, stopping StartToGoal")
return nil
return allSubActions, nil
}
// Execute actions only if task is not finished
if err := dExt.executeActions(ctx, result.ToolCalls); err != nil {
return err
// Invoke tool calls
subActions, err := dExt.invokeToolCalls(ctx, result.Thought, result.ToolCalls)
allSubActions = append(allSubActions, subActions...)
if err != nil {
return allSubActions, err
}
if options.MaxRetryTimes > 1 && attempt >= options.MaxRetryTimes {
return errors.New("reached max retry times")
return allSubActions, errors.New("reached max retry times")
}
}
}
func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...option.ActionOption) error {
func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...option.ActionOption) ([]*SubActionResult, error) {
log.Info().Str("prompt", prompt).Msg("performing AI action")
// plan next action
result, err := dExt.PlanNextAction(ctx, prompt, opts...)
if err != nil {
return err
return nil, err
}
// execute actions
return dExt.executeActions(ctx, result.ToolCalls)
// Invoke tool calls
subActionResults, err := dExt.invokeToolCalls(ctx, result.Thought, result.ToolCalls)
if err != nil {
return subActionResults, err
}
return subActionResults, nil
}
func (dExt *XTDriver) PlanNextAction(ctx context.Context, prompt string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
@@ -87,36 +94,28 @@ func (dExt *XTDriver) PlanNextAction(ctx context.Context, prompt string, opts ..
return nil, errors.New("LLM service is not initialized")
}
compressedBufSource, err := getScreenShotBuffer(dExt.IDriver)
// Parse action options to get ResetHistory setting
options := option.NewActionOptions(opts...)
resetHistory := options.ResetHistory
// Use GetScreenResult to handle screenshot capture, save, and session tracking
screenResult, err := dExt.GetScreenResult(
option.WithScreenShotFileName(builtin.GenNameWithTimestamp("%d_screenshot")),
)
if err != nil {
return nil, err
}
// convert buffer to base64 string
// convert buffer to base64 string for LLM
screenShotBase64 := "data:image/jpeg;base64," +
base64.StdEncoding.EncodeToString(compressedBufSource.Bytes())
// save screenshot to file
imagePath := filepath.Join(
config.GetConfig().ScreenShotsPath,
fmt.Sprintf("%s.jpeg", builtin.GenNameWithTimestamp("%d_screenshot")),
)
go func() {
err := saveScreenShot(compressedBufSource, imagePath)
if err != nil {
log.Error().Err(err).Msg("save screenshot file failed")
}
}()
base64.StdEncoding.EncodeToString(screenResult.bufSource.Bytes())
// get window size
size, err := dExt.IDriver.WindowSize()
if err != nil {
return nil, errors.Wrap(code.DeviceGetInfoError, err.Error())
}
// Parse action options to get ResetHistory setting
options := option.NewActionOptions(opts...)
resetHistory := options.ResetHistory
planningOpts := &ai.PlanningOptions{
UserInstruction: prompt,
Message: &schema.Message{
@@ -160,23 +159,40 @@ func (dExt *XTDriver) isTaskFinished(result *ai.PlanningResult) bool {
return false
}
// executeActions executes the planned actions
func (dExt *XTDriver) executeActions(ctx context.Context, toolCalls []schema.ToolCall) error {
// invokeToolCalls invokes the tool calls and returns sub-action results
func (dExt *XTDriver) invokeToolCalls(ctx context.Context, thought string, toolCalls []schema.ToolCall) ([]*SubActionResult, error) {
var subActionResults []*SubActionResult
for _, action := range toolCalls {
// Check for context cancellation before each action
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in executeActions")
return errors.Wrap(code.InterruptError, "executeActions interrupted")
log.Warn().Msg("interrupted in invokeToolCalls")
return subActionResults, errors.Wrap(code.InterruptError, "invokeToolCalls interrupted")
default:
}
// call eino tool
subActionStartTime := time.Now()
// Extract action name (remove "uixt__" prefix)
actionName := strings.TrimPrefix(action.Function.Name, "uixt__")
// Parse arguments
arguments := make(map[string]interface{})
err := json.Unmarshal([]byte(action.Function.Arguments), &arguments)
if err != nil {
return err
return subActionResults, err
}
// Create sub-action result
subActionResult := &SubActionResult{
ActionName: actionName,
Arguments: arguments,
StartTime: subActionStartTime.Unix(),
Thought: thought,
}
// Execute the action
req := mcp.CallToolRequest{
Params: struct {
Name string `json:"name"`
@@ -191,12 +207,42 @@ func (dExt *XTDriver) executeActions(ctx context.Context, toolCalls []schema.Too
}
_, err = dExt.client.CallTool(ctx, req)
subActionResult.Elapsed = time.Since(subActionStartTime).Milliseconds()
if err != nil {
return err
subActionResult.Error = err
subActionResults = append(subActionResults, subActionResult)
return subActionResults, err
}
// Collect sub-action specific attachments and reset session data
subActionData := dExt.GetData(true) // reset after getting data
// Add requests if any
if requests, ok := subActionData["requests"].([]*DriverRequests); ok && len(requests) > 0 {
subActionResult.Requests = requests
}
// Add screen_results if any
if screenResults, ok := subActionData["screen_results"].([]*ScreenResult); ok && len(screenResults) > 0 {
subActionResult.ScreenResults = screenResults
}
subActionResults = append(subActionResults, subActionResult)
}
return nil
return subActionResults, nil
}
// SubActionResult represents a sub-action within a start_to_goal action
type SubActionResult struct {
ActionName string `json:"action_name"` // name of the sub-action (e.g., "tap", "input")
Arguments interface{} `json:"arguments,omitempty"` // arguments passed to the sub-action
StartTime int64 `json:"start_time"` // sub-action start time
Elapsed int64 `json:"elapsed_ms"` // sub-action elapsed time(ms)
Error error `json:"error,omitempty"` // sub-action execution result
Thought string `json:"thought,omitempty"` // sub-action thought
Requests []*DriverRequests `json:"requests,omitempty"` // store sub-action specific requests
ScreenResults []*ScreenResult `json:"screen_results,omitempty"` // store sub-action specific screen_results
}
func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (string, error) {

View File

@@ -15,8 +15,9 @@ import (
func TestDriverExt_TapByLLM(t *testing.T) {
driver := setupDriverExt(t)
err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像")
subActionResults, err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像")
assert.Nil(t, err)
t.Log(subActionResults)
err = driver.AIAssert("当前在个人介绍页")
assert.Nil(t, err)
@@ -46,7 +47,7 @@ func TestDriverExt_StartToGoal(t *testing.T) {
userInstruction += "\n\n请严格按照以上游戏规则开始游戏注意请只做点击操作"
err := driver.StartToGoal(context.Background(), userInstruction)
_, err := driver.StartToGoal(context.Background(), userInstruction)
assert.Nil(t, err)
}

View File

@@ -13,7 +13,8 @@ import (
// ToolStartToGoal implements the start_to_goal tool call.
type ToolStartToGoal struct {
// Return data fields - these define the structure of data returned by this tool
Prompt string `json:"prompt" desc:"Goal prompt that was executed"`
Prompt string `json:"prompt" desc:"Goal prompt that was executed"`
SubActions []*SubActionResult `json:"sub_actions" desc:"Sub-actions that were executed"`
}
func (t *ToolStartToGoal) Name() option.ActionName {
@@ -42,14 +43,15 @@ func (t *ToolStartToGoal) Implement() server.ToolHandlerFunc {
}
// Start to goal logic
err = driverExt.StartToGoal(ctx, unifiedReq.Prompt)
subActionResults, err := driverExt.StartToGoal(ctx, unifiedReq.Prompt)
if err != nil {
return NewMCPErrorResponse(fmt.Sprintf("Failed to achieve goal: %s", err.Error())), nil
}
message := fmt.Sprintf("Successfully achieved goal: %s", unifiedReq.Prompt)
returnData := ToolStartToGoal{
Prompt: unifiedReq.Prompt,
Prompt: unifiedReq.Prompt,
SubActions: subActionResults,
}
return NewMCPSuccessResponse(message, &returnData), nil
@@ -73,7 +75,8 @@ func (t *ToolStartToGoal) ConvertActionToCallToolRequest(action option.MobileAct
// ToolAIAction implements the ai_action tool call.
type ToolAIAction struct {
// Return data fields - these define the structure of data returned by this tool
Prompt string `json:"prompt" desc:"AI action prompt that was executed"`
Prompt string `json:"prompt" desc:"AI action prompt that was executed"`
SubActions []*SubActionResult `json:"sub_actions" desc:"Sub-actions that were executed"`
}
func (t *ToolAIAction) Name() option.ActionName {
@@ -102,14 +105,15 @@ func (t *ToolAIAction) Implement() server.ToolHandlerFunc {
}
// AI action logic
err = driverExt.AIAction(ctx, unifiedReq.Prompt)
subActionResults, err := driverExt.AIAction(ctx, unifiedReq.Prompt)
if err != nil {
return NewMCPErrorResponse(fmt.Sprintf("AI action failed: %s", err.Error())), nil
}
message := fmt.Sprintf("Successfully performed AI action with prompt: %s", unifiedReq.Prompt)
returnData := ToolAIAction{
Prompt: unifiedReq.Prompt,
Prompt: unifiedReq.Prompt,
SubActions: subActionResults,
}
return NewMCPSuccessResponse(message, &returnData), nil

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"strings"
"time"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/mark3labs/mcp-go/client"
@@ -88,37 +90,114 @@ func (c *MCPClient4XTDriver) GetToolByAction(actionName option.ActionName) Actio
return c.Server.GetToolByAction(actionName)
}
func (dExt *XTDriver) ExecuteAction(ctx context.Context, action option.MobileAction) (err error) {
func (dExt *XTDriver) ExecuteAction(ctx context.Context, action option.MobileAction) ([]*SubActionResult, error) {
subActionStartTime := time.Now()
// Find the corresponding tool for this action method
tool := dExt.client.Server.GetToolByAction(action.Method)
if tool == nil {
return fmt.Errorf("no tool found for action method: %s", action.Method)
return nil, fmt.Errorf("no tool found for action method: %s", action.Method)
}
// Use the tool's own conversion method
req, err := tool.ConvertActionToCallToolRequest(action)
if err != nil {
return fmt.Errorf("failed to convert action to MCP tool call: %w", err)
return nil, fmt.Errorf("failed to convert action to MCP tool call: %w", err)
}
// Create sub-action result
subActionResult := &SubActionResult{
ActionName: string(action.Method),
Arguments: action.Params,
StartTime: subActionStartTime.Unix(),
}
// Execute via MCP tool
result, err := dExt.client.CallTool(ctx, req)
subActionResult.Elapsed = time.Since(subActionStartTime).Milliseconds()
if err != nil {
return fmt.Errorf("MCP tool call failed: %w", err)
subActionResult.Error = err
return []*SubActionResult{subActionResult}, fmt.Errorf("MCP tool call failed: %w", err)
}
// Check if the tool execution had business logic errors
if result.IsError {
var errMsg string
if len(result.Content) > 0 {
return fmt.Errorf("invoke tool %s failed: %v",
tool.Name(), result.Content)
errMsg = fmt.Sprintf("invoke tool %s failed: %v", tool.Name(), result.Content)
} else {
errMsg = fmt.Sprintf("invoke tool %s failed", tool.Name())
}
return fmt.Errorf("invoke tool %s failed", tool.Name())
err := errors.New(errMsg)
subActionResult.Error = err
return []*SubActionResult{subActionResult}, err
}
// Handle special AI actions (start_to_goal, ai_action) that return sub-actions
if action.Method == option.ACTION_StartToGoal || action.Method == option.ACTION_AIAction {
return dExt.parseAIActionResult(result, subActionResult)
}
// For regular actions, collect session data and return single sub-action result
subActionData := dExt.GetData(true) // reset after getting data
// Add requests if any
if requests, ok := subActionData["requests"].([]*DriverRequests); ok && len(requests) > 0 {
subActionResult.Requests = requests
}
// Add screen_results if any
if screenResults, ok := subActionData["screen_results"].([]*ScreenResult); ok && len(screenResults) > 0 {
subActionResult.ScreenResults = screenResults
}
log.Debug().Str("tool", string(tool.Name())).
Msg("execute action via MCP tool")
return nil
return []*SubActionResult{subActionResult}, nil
}
// parseAIActionResult parses the result from AI actions (start_to_goal, ai_action) and extracts sub-actions
func (dExt *XTDriver) parseAIActionResult(result *mcp.CallToolResult, originalSubAction *SubActionResult) ([]*SubActionResult, error) {
// Parse the JSON response to extract sub_actions
var responseData map[string]interface{}
if len(result.Content) > 0 {
// Get the first text content
if textContent, ok := result.Content[0].(mcp.TextContent); ok {
if err := json.Unmarshal([]byte(textContent.Text), &responseData); err != nil {
log.Warn().Err(err).Msg("failed to parse AI action result, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
} else {
log.Warn().Msg("AI action result is not text content, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
}
// Extract sub_actions from the response
if subActionsData, ok := responseData["sub_actions"]; ok {
// Convert to JSON and back to properly deserialize SubActionResult structs
subActionsJSON, err := json.Marshal(subActionsData)
if err != nil {
log.Warn().Err(err).Msg("failed to marshal sub_actions, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
var subActionResults []*SubActionResult
if err := json.Unmarshal(subActionsJSON, &subActionResults); err != nil {
log.Warn().Err(err).Msg("failed to unmarshal sub_actions, falling back to single action")
return []*SubActionResult{originalSubAction}, nil
}
log.Debug().Int("sub_actions_count", len(subActionResults)).
Str("action", string(originalSubAction.ActionName)).
Msg("parsed AI action sub-actions")
return subActionResults, nil
}
// If no sub_actions found, return the original action as a single result
log.Debug().Str("action", string(originalSubAction.ActionName)).
Msg("no sub_actions found in AI action result, using single action")
return []*SubActionResult{originalSubAction}, nil
}
// NewDeviceWithDefault is a helper function to create a device with default options