feat: implement ToolStartToGoal and fix LLM service initialization

- Add ToolStartToGoal implementation with AI-driven goal automation
- Fix the uninitialized LLM service by applying the global AI config when the XTDriver is created
- Ensure XTDriver is created with proper AI services from the first initialization
- Add StartToGoal method to StepMobile for goal-oriented automation
- Register ToolStartToGoal in MCP server and add corresponding action type
- Add comprehensive test case for StartToGoal functionality
- Fix ReturnSchema consistency across AI tools (StartToGoal, AIAction, Finished)
- Extract AI service options in MCP argument processing

This resolves the root cause, where runStepMobileUI created the XTDriver
without AI services: the XTDriver is now initialized only once, with the
complete AI service configuration.
This commit is contained in:
lilong.129
2025-06-05 16:52:11 +08:00
parent 0add3231ff
commit c4e7ab00a7
7 changed files with 199 additions and 19 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2506051419 v5.0.0-beta-2506051652

View File

@@ -177,6 +177,18 @@ func (s *StepMobile) TapByUITypes(opts ...option.ActionOption) *StepMobile {
return s return s
} }
// StartToGoal appends a goal-oriented action (performed with VLM) to the step.
//
// prompt is stored as the action's Params and opts are folded into the
// action's Options. The receiver is returned so that calls can be chained.
func (s *StepMobile) StartToGoal(prompt string, opts ...option.ActionOption) *StepMobile {
	goalAction := option.MobileAction{
		Method:  option.ACTION_StartToGoal,
		Params:  prompt,
		Options: option.NewActionOptions(opts...),
	}
	target := s.obj()
	target.Actions = append(target.Actions, goalAction)
	return s
}
// AIAction do actions with VLM // AIAction do actions with VLM
func (s *StepMobile) AIAction(prompt string, opts ...option.ActionOption) *StepMobile { func (s *StepMobile) AIAction(prompt string, opts ...option.ActionOption) *StepMobile {
action := option.MobileAction{ action := option.MobileAction{
@@ -707,6 +719,29 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
Platform: mobileStep.OSType, Platform: mobileStep.OSType,
Serial: mobileStep.Serial, Serial: mobileStep.Serial,
} }
// Extract AI service options from global configuration
if s.caseRunner != nil && s.caseRunner.Config != nil {
globalConfig := s.caseRunner.Config.Get()
if globalConfig != nil {
var aiOpts []option.AIServiceOption
// Add LLM service if configured
if globalConfig.LLMService != "" {
aiOpts = append(aiOpts, option.WithLLMService(globalConfig.LLMService))
log.Debug().Str("llmService", string(globalConfig.LLMService)).Msg("Applied global LLM service to XTDriver config")
}
// Add CV service if configured
if globalConfig.CVService != "" {
aiOpts = append(aiOpts, option.WithCVService(globalConfig.CVService))
log.Debug().Str("cvService", string(globalConfig.CVService)).Msg("Applied global CV service to XTDriver config")
}
config.AIOptions = aiOpts
}
}
uiDriver, err := uixt.GetOrCreateXTDriver(config) uiDriver, err := uixt.GetOrCreateXTDriver(config)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -810,17 +845,30 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
return stepResult, err return stepResult, err
} }
// Apply global AntiRisk configuration if enabled in testcase config // Apply global configuration from testcase config
if s.caseRunner != nil && s.caseRunner.Config != nil { if s.caseRunner != nil && s.caseRunner.Config != nil {
config := s.caseRunner.Config.Get() config := s.caseRunner.Config.Get()
if config != nil && config.AntiRisk { if config != nil {
if action.Options == nil { if action.Options == nil {
action.Options = &option.ActionOptions{} action.Options = &option.ActionOptions{}
} }
// Only set AntiRisk to true if it's not already explicitly set to false
if !action.Options.AntiRisk { // Apply global AntiRisk configuration
if config.AntiRisk && !action.Options.AntiRisk {
action.Options.AntiRisk = true action.Options.AntiRisk = true
} }
// Apply global LLM service configuration for AI actions
if action.Method == option.ACTION_AIAction || action.Method == option.ACTION_StartToGoal {
if config.LLMService != "" && action.Options.LLMService == "" {
action.Options.LLMService = string(config.LLMService)
log.Debug().Str("action", string(action.Method)).Str("llmService", action.Options.LLMService).Msg("Applied global LLM service config to action")
}
if config.CVService != "" && action.Options.CVService == "" {
action.Options.CVService = string(config.CVService)
log.Debug().Str("action", string(action.Method)).Str("cvService", action.Options.CVService).Msg("Applied global CV service config to action")
}
}
} }
} }

View File

@@ -81,9 +81,50 @@ func TestAndroidAction(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
} }
// TestStartToGoal runs a two-step Android test case driven by StartToGoal,
// with the Doubao VL LLM service set on the case config. The first step's goal
// is to launch the mini game and is followed by an AI assertion on the
// resulting page; the second step passes the full game rules as the goal and
// allows up to 100 retries.
func TestStartToGoal(t *testing.T) {
	// Rules of the Lianliankan tile-matching game, passed verbatim as the goal prompt.
	userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明:
1. 游戏目标: 玩家需要在规定时间内,通过连接相同的图案或图标,将它们从游戏界面中消除。
2. 连接规则:
- 两个相同的图案可以通过不超过三条直线连接。
- 连接线可以水平或垂直,但不能斜线,也不能跨过其他图案。
- 连接线的转折次数不能超过两次。
3. 游戏界面:
- 游戏界面通常是一个矩形区域,内含多个图案或图标,排列成行和列。
- 图案或图标在未选中状态下背景为白色,选中状态下背景为绿色。
4. 时间限制: 游戏通常设有时间限制,玩家需要在时间耗尽前完成所有图案的消除。
5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。
6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。
注意事项:
1、当连接错误时顶部的红心会减少一个需及时调整策略避免红心变为0个后游戏失败
2、不要连续 2 次点击同一个图案
3、不要犯重复的错误
请严格按照以上游戏规则,开始游戏
`

	caseConfig := hrp.NewConfig("run ui action with start to goal").
		SetLLMService(option.LLMServiceTypeDoubaoVL)

	// Step 1: reach the mini game, then verify the page with an AI assertion.
	launchGame := hrp.NewStep("启动抖音「连了又连」小游戏").
		Android().
		StartToGoal("启动抖音,搜索「连了又连」小游戏,并启动游戏").
		Validate().
		AssertAI("当前位于抖音「连了又连」小游戏页面")

	// Step 2: play the game according to the rules above.
	playGame := hrp.NewStep("开始游戏").
		Android().
		StartToGoal(userInstruction, option.WithMaxRetryTimes(100))

	testCase := &hrp.TestCase{
		Config:    caseConfig,
		TestSteps: []hrp.IStep{launchGame, playGame},
	}

	runErr := hrp.NewRunner(t).Run(testCase)
	assert.Nil(t, runErr)
}
func TestAIAction(t *testing.T) { func TestAIAction(t *testing.T) {
testCase := &hrp.TestCase{ testCase := &hrp.TestCase{
Config: hrp.NewConfig("run ui action with ai"), Config: hrp.NewConfig("run ui action with ai").
SetLLMService(option.LLMServiceTypeDoubaoVL),
TestSteps: []hrp.IStep{ TestSteps: []hrp.IStep{
hrp.NewStep("launch settings"). hrp.NewStep("launch settings").
Android().AIAction("进入手机系统设置"). Android().AIAction("进入手机系统设置").

View File

@@ -279,9 +279,23 @@ func setupXTDriver(_ context.Context, args map[string]any) (*XTDriver, error) {
platform, _ := args["platform"].(string) platform, _ := args["platform"].(string)
serial, _ := args["serial"].(string) serial, _ := args["serial"].(string)
// Extract AI service options from arguments if provided
var aiOpts []option.AIServiceOption
// Check for LLM service type
if llmService, ok := args["llm_service"].(string); ok && llmService != "" {
aiOpts = append(aiOpts, option.WithLLMService(option.LLMServiceType(llmService)))
}
// Check for CV service type
if cvService, ok := args["cv_service"].(string); ok && cvService != "" {
aiOpts = append(aiOpts, option.WithCVService(option.CVServiceType(cvService)))
}
config := DriverCacheConfig{ config := DriverCacheConfig{
Platform: platform, Platform: platform,
Serial: serial, Serial: serial,
AIOptions: aiOpts,
} }
return GetOrCreateXTDriver(config) return GetOrCreateXTDriver(config)
} }

View File

@@ -121,6 +121,7 @@ func (s *MCPServer4XTDriver) registerTools() {
s.registerTool(&ToolWebCloseTab{}) s.registerTool(&ToolWebCloseTab{})
// AI Tools // AI Tools
s.registerTool(&ToolStartToGoal{})
s.registerTool(&ToolAIAction{}) s.registerTool(&ToolAIAction{})
s.registerTool(&ToolFinished{}) s.registerTool(&ToolFinished{})
} }
@@ -214,6 +215,14 @@ func extractActionOptionsToArguments(actionOptions []option.ActionOption, argume
if tempOptions.PressDuration > 0 { if tempOptions.PressDuration > 0 {
arguments["press_duration"] = tempOptions.PressDuration arguments["press_duration"] = tempOptions.PressDuration
} }
// Add AI service options
if tempOptions.LLMService != "" {
arguments["llm_service"] = tempOptions.LLMService
}
if tempOptions.CVService != "" {
arguments["cv_service"] = tempOptions.CVService
}
} }
func getFloat64ValueOrDefault(value float64, defaultValue float64) float64 { func getFloat64ValueOrDefault(value float64, defaultValue float64) float64 {

View File

@@ -10,6 +10,65 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// ToolStartToGoal implements the start_to_goal tool call.
type ToolStartToGoal struct{}

// Name returns the action name of this tool, option.ACTION_StartToGoal.
func (t *ToolStartToGoal) Name() option.ActionName {
	return option.ACTION_StartToGoal
}

// Description returns the human-readable summary of the tool.
func (t *ToolStartToGoal) Description() string {
	const summary = "Start AI-driven automation to achieve a specific goal using natural language description"
	return summary
}

// Options returns the MCP tool options that ActionOptions declares for the
// start_to_goal action.
func (t *ToolStartToGoal) Options() []mcp.ToolOption {
	return (&option.ActionOptions{}).GetMCPOptions(option.ACTION_StartToGoal)
}
// Implement returns the handler that serves start_to_goal tool calls.
//
// The handler obtains an XTDriver from the request arguments, parses the same
// arguments into action options, and runs StartToGoal with the parsed prompt.
// Driver setup and argument parsing failures are returned as Go errors; a
// failure of the goal run itself is reported as an MCP error result together
// with a nil error.
func (t *ToolStartToGoal) Implement() server.ToolHandlerFunc {
	handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		args := request.Params.Arguments

		xtDriver, err := setupXTDriver(ctx, args)
		if err != nil {
			return nil, fmt.Errorf("setup driver failed: %w", err)
		}

		actionOpts, err := parseActionOptions(args)
		if err != nil {
			return nil, err
		}

		// Run the goal described by the prompt.
		goal := actionOpts.Prompt
		log.Info().Str("prompt", goal).Msg("starting to goal")
		if goalErr := xtDriver.StartToGoal(goal); goalErr != nil {
			return mcp.NewToolResultError(fmt.Sprintf("Failed to achieve goal: %s", goalErr.Error())), nil
		}

		return mcp.NewToolResultText(fmt.Sprintf("Successfully achieved goal: %s", goal)), nil
	}
	return handler
}
// ConvertActionToCallToolRequest turns a mobile action into an MCP call
// request for this tool.
//
// The action's Params must be a string; it becomes the "prompt" argument, and
// the action's options are merged into the same argument map. Any other Params
// type yields an empty request and an error.
func (t *ToolStartToGoal) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
	prompt, isString := action.Params.(string)
	if !isString {
		return mcp.CallToolRequest{}, fmt.Errorf("invalid start to goal params: %v", action.Params)
	}

	arguments := map[string]any{"prompt": prompt}
	// Extract options to arguments
	extractActionOptionsToArguments(action.GetOptions(), arguments)

	return buildMCPCallToolRequest(t.Name(), arguments), nil
}
// ReturnSchema describes the fields of the tool's result: a single "message"
// entry whose value documents the field's type and meaning.
func (t *ToolStartToGoal) ReturnSchema() map[string]string {
	schema := make(map[string]string, 1)
	schema["message"] = "string: Success message confirming goal was achieved, or error message if failed"
	return schema
}
// ToolAIAction implements the ai_action tool call. // ToolAIAction implements the ai_action tool call.
type ToolAIAction struct{} type ToolAIAction struct{}
@@ -54,6 +113,10 @@ func (t *ToolAIAction) ConvertActionToCallToolRequest(action option.MobileAction
arguments := map[string]any{ arguments := map[string]any{
"prompt": prompt, "prompt": prompt,
} }
// Extract options to arguments
extractActionOptionsToArguments(action.GetOptions(), arguments)
return buildMCPCallToolRequest(t.Name(), arguments), nil return buildMCPCallToolRequest(t.Name(), arguments), nil
} }
return mcp.CallToolRequest{}, fmt.Errorf("invalid AI action params: %v", action.Params) return mcp.CallToolRequest{}, fmt.Errorf("invalid AI action params: %v", action.Params)
@@ -61,9 +124,7 @@ func (t *ToolAIAction) ConvertActionToCallToolRequest(action option.MobileAction
func (t *ToolAIAction) ReturnSchema() map[string]string { func (t *ToolAIAction) ReturnSchema() map[string]string {
return map[string]string{ return map[string]string{
"message": "string: Success message confirming AI action was performed", "message": "string: Success message confirming AI action was performed, or error message if failed",
"prompt": "string: Natural language prompt that was processed",
"actionTaken": "string: Description of the specific action that was taken by AI",
} }
} }
@@ -107,8 +168,6 @@ func (t *ToolFinished) ConvertActionToCallToolRequest(action option.MobileAction
func (t *ToolFinished) ReturnSchema() map[string]string { func (t *ToolFinished) ReturnSchema() map[string]string {
return map[string]string{ return map[string]string{
"message": "string: Success message confirming task completion", "message": "string: Success message confirming task completion, or error message if failed",
"content": "string: Completion reason or result description",
"taskCompleted": "bool: Boolean indicating task was successfully finished",
} }
} }

View File

@@ -73,7 +73,6 @@ const (
ACTION_KeyCode ActionName = "keycode" ACTION_KeyCode ActionName = "keycode"
ACTION_Delete ActionName = "delete" // delete action ACTION_Delete ActionName = "delete" // delete action
ACTION_Backspace ActionName = "backspace" // backspace action ACTION_Backspace ActionName = "backspace" // backspace action
ACTION_AIAction ActionName = "ai_action" // action with ai
ACTION_TapBySelector ActionName = "tap_by_selector" ACTION_TapBySelector ActionName = "tap_by_selector"
ACTION_HoverBySelector ActionName = "hover_by_selector" ACTION_HoverBySelector ActionName = "hover_by_selector"
ACTION_Hover ActionName = "hover" // generic hover action ACTION_Hover ActionName = "hover" // generic hover action
@@ -101,9 +100,13 @@ const (
ACTION_InstallApp ActionName = "install_app" ACTION_InstallApp ActionName = "install_app"
ACTION_UninstallApp ActionName = "uninstall_app" ACTION_UninstallApp ActionName = "uninstall_app"
ACTION_DownloadApp ActionName = "download_app" ACTION_DownloadApp ActionName = "download_app"
ACTION_Finished ActionName = "finished"
ACTION_CallFunction ActionName = "call_function" ACTION_CallFunction ActionName = "call_function"
// AI actions
ACTION_StartToGoal ActionName = "start_to_goal" // start to goal action
ACTION_AIAction ActionName = "ai_action" // action with ai
ACTION_Finished ActionName = "finished" // finished action
// anti-risk actions // anti-risk actions
ACTION_SetTouchInfo ActionName = "set_touch_info" ACTION_SetTouchInfo ActionName = "set_touch_info"
ACTION_SetTouchInfoList ActionName = "set_touch_info_list" ACTION_SetTouchInfoList ActionName = "set_touch_info_list"
@@ -178,8 +181,10 @@ type ActionOptions struct {
Params []float64 `json:"params,omitempty" yaml:"params,omitempty" desc:"Generic parameter array"` Params []float64 `json:"params,omitempty" yaml:"params,omitempty" desc:"Generic parameter array"`
// AI related // AI related
Prompt string `json:"prompt,omitempty" yaml:"prompt,omitempty" desc:"AI action prompt"` Prompt string `json:"prompt,omitempty" yaml:"prompt,omitempty" desc:"AI action prompt"`
Content string `json:"content,omitempty" yaml:"content,omitempty" desc:"Content for finished action"` Content string `json:"content,omitempty" yaml:"content,omitempty" desc:"Content for finished action"`
LLMService string `json:"llm_service,omitempty" yaml:"llm_service,omitempty" desc:"LLM service type for AI actions"`
CVService string `json:"cv_service,omitempty" yaml:"cv_service,omitempty" desc:"Computer vision service type for AI actions"`
// Time related // Time related
Seconds float64 `json:"seconds,omitempty" yaml:"seconds,omitempty" desc:"Sleep duration in seconds"` Seconds float64 `json:"seconds,omitempty" yaml:"seconds,omitempty" desc:"Sleep duration in seconds"`
@@ -679,6 +684,9 @@ func (o *ActionOptions) validateActionSpecificFields(actionType ActionName) erro
ACTION_AIAction: func() error { ACTION_AIAction: func() error {
return o.requireFields("prompt", o.Prompt != "") return o.requireFields("prompt", o.Prompt != "")
}, },
ACTION_StartToGoal: func() error {
return o.requireFields("prompt", o.Prompt != "")
},
ACTION_Finished: func() error { ACTION_Finished: func() error {
return o.requireFields("content", o.Content != "") return o.requireFields("content", o.Content != "")
}, },
@@ -750,7 +758,8 @@ func (o *ActionOptions) GetMCPOptions(actionType ActionName) []mcp.ToolOption {
ACTION_Sleep: {"seconds"}, ACTION_Sleep: {"seconds"},
ACTION_SleepMS: {"platform", "serial", "milliseconds"}, ACTION_SleepMS: {"platform", "serial", "milliseconds"},
ACTION_SleepRandom: {"platform", "serial", "params"}, ACTION_SleepRandom: {"platform", "serial", "params"},
ACTION_AIAction: {"platform", "serial", "prompt"}, ACTION_AIAction: {"platform", "serial", "prompt", "llm_service", "cv_service"},
ACTION_StartToGoal: {"platform", "serial", "prompt", "llm_service", "cv_service"},
ACTION_Finished: {"content"}, ACTION_Finished: {"content"},
ACTION_ListAvailableDevices: {}, ACTION_ListAvailableDevices: {},
ACTION_SelectDevice: {"platform", "serial"}, ACTION_SelectDevice: {"platform", "serial"},