diff --git a/internal/version/VERSION b/internal/version/VERSION index 521a4254..2ecd00d5 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2505271534 +v5.0.0-beta-2505271946 diff --git a/mcphost/host.go b/mcphost/host.go index 245f3ca7..286bc28a 100644 --- a/mcphost/host.go +++ b/mcphost/host.go @@ -26,7 +26,6 @@ type MCPHost struct { connections map[string]*Connection config *MCPConfig withUIXT bool - drivers map[string]*uixt.XTDriver } // Connection represents a connection to an MCP server @@ -52,7 +51,6 @@ func NewMCPHost(configPath string, withUIXT bool) (*MCPHost, error) { host := &MCPHost{ connections: make(map[string]*Connection), config: config, - drivers: make(map[string]*uixt.XTDriver), withUIXT: withUIXT, } @@ -175,6 +173,18 @@ func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) { return conn.Client, nil } +// GetAllClients returns all MCP clients +func (h *MCPHost) GetAllClients() map[string]client.MCPClient { + h.mu.RLock() + defer h.mu.RUnlock() + + clients := make(map[string]client.MCPClient) + for name, conn := range h.connections { + clients[name] = conn.Client + } + return clients +} + // GetTools returns all tools from all MCP servers func (h *MCPHost) GetTools(ctx context.Context) []MCPTools { h.mu.RLock() @@ -204,28 +214,20 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc h.mu.RLock() defer h.mu.RUnlock() - // Get all tools - results := h.GetTools(ctx) - - // Find the server's tools - var serverTools MCPTools - found := false - for _, tools := range results { - if tools.ServerName == serverName { - serverTools = tools - found = true - break - } - } - if !found { + // Get connection for the server + conn, exists := h.connections[serverName] + if !exists { return nil, fmt.Errorf("no connection found for MCP server %s", serverName) } - if serverTools.Err != nil { - return nil, serverTools.Err + + // Get tools from the specific server + listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to get tools from server %s: %w", serverName, err) } // Find the specific tool - for _, tool := range serverTools.Tools { + for _, tool := range listResults.Tools { if tool.Name == toolName { return &tool, nil } diff --git a/runner.go b/runner.go index 8f86f184..cafc28d9 100644 --- a/runner.go +++ b/runner.go @@ -495,10 +495,19 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) { // init XTDriver and register to unified cache for _, driverConfig := range driverConfigs { - _, err := uixt.GetOrCreateXTDriver(driverConfig) + driver, err := uixt.GetOrCreateXTDriver(driverConfig) if err != nil { return nil, errors.Wrapf(err, "init %s XTDriver failed", driverConfig.Platform) } + + // Set MCP clients if MCPHost is available + if r.parser.MCPHost != nil { + mcpClients := r.parser.MCPHost.GetAllClients() + driver.SetMCPClients(mcpClients) + log.Debug().Str("serial", driverConfig.Serial). + Int("mcp_clients", len(mcpClients)). + Msg("Set MCP clients for XTDriver") + } } return parsedConfig, nil diff --git a/server/uixt.go b/server/uixt.go index 71f11f0a..7722f540 100644 --- a/server/uixt.go +++ b/server/uixt.go @@ -19,7 +19,7 @@ func (r *Router) uixtActionHandler(c *gin.Context) { return } - if err = dExt.ExecuteAction(req); err != nil { + if err = dExt.ExecuteAction(c.Request.Context(), req); err != nil { log.Err(err).Interface("action", req). Msg("exec uixt action failed") RenderError(c, err) @@ -42,7 +42,7 @@ func (r *Router) uixtActionsHandler(c *gin.Context) { } for _, action := range actions { - if err = dExt.ExecuteAction(action); err != nil { + if err = dExt.ExecuteAction(c.Request.Context(), action); err != nil { log.Err(err).Interface("action", action). Msg("exec uixt action failed") RenderError(c, err) diff --git a/step_ui.go b/step_ui.go index 3b776ef6..ad13749b 100644 --- a/step_ui.go +++ b/step_ui.go @@ -1,6 +1,7 @@ package hrp import ( + "context" "fmt" "strings" "time" @@ -803,7 +804,7 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err continue } - err = uiDriver.ExecuteAction(action) + err = uiDriver.ExecuteAction(context.Background(), action) actionResult.Elapsed = time.Since(actionStartTime).Milliseconds() stepResult.Actions = append(stepResult.Actions, actionResult) if err != nil { diff --git a/uixt/driver_ext_ai.go b/uixt/driver_ext_ai.go index 169b65b0..30910a20 100644 --- a/uixt/driver_ext_ai.go +++ b/uixt/driver_ext_ai.go @@ -62,7 +62,7 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error { }, } - _, err = dExt.Client.CallTool(context.Background(), req) + _, err = dExt.client.CallTool(context.Background(), req) if err != nil { return err } diff --git a/uixt/driver_handler.go b/uixt/driver_handler.go index 86acb3e7..0abd7d87 100644 --- a/uixt/driver_handler.go +++ b/uixt/driver_handler.go @@ -1,6 +1,7 @@ package uixt import ( + "context" "fmt" "path/filepath" "time" @@ -8,6 +9,7 @@ import ( "github.com/httprunner/httprunner/v5/internal/builtin" "github.com/httprunner/httprunner/v5/internal/config" "github.com/httprunner/httprunner/v5/uixt/option" + "github.com/mark3labs/mcp-go/mcp" "github.com/rs/zerolog/log" ) @@ -47,6 +49,14 @@ func (dExt *XTDriver) Call(desc string, fn func(), opts ...option.ActionOption) func preHandler_TapAbsXY(driver IDriver, options *option.ActionOptions, rawX, rawY float64) ( x, y float64, err error) { + // Call MCP action tool if anti-risk is enabled + if options.AntiRisk { + callMCPActionTool(driver, option.ACTION_TapAbsXY, map[string]any{ + "x": rawX, + "y": rawY, + }) + } + x, y = options.ApplyTapOffset(rawX, rawY) // mark UI operation @@ -143,3 +153,131 @@ func postHandler(driver IDriver, actionType option.ActionName, options *option.A } return nil } + +// callMCPActionTool calls MCP tool for the given action +func callMCPActionTool(driver IDriver, actionType option.ActionName, arguments map[string]any) { + // Get XTDriver from cache + dExt := getXTDriverFromCache(driver) + if dExt == nil { + return + } + + // Define action to MCP server mapping for pre-hooks + serverMapping := getPreHookServerMapping(actionType) + if serverMapping == nil { + return // No MCP hook configured for this action + } + + callMCPTool(dExt, serverMapping.ServerName, serverMapping.ToolName, arguments, actionType) +} + +// MCPServerMapping defines the mapping between action and MCP server/tool +type MCPServerMapping struct { + ServerName string + ToolName string +} + +// getPreHookServerMapping returns MCP server mapping for pre-hooks +// TODO: You can customize these mappings according to your needs +func getPreHookServerMapping(actionType option.ActionName) *MCPServerMapping { + mappings := map[option.ActionName]*MCPServerMapping{ + option.ACTION_TapAbsXY: { + ServerName: "evalpkgs", + ToolName: "log_pre_action", + }, + // Add more mappings as needed + // option.ACTION_Swipe: { + // ServerName: "monitor", + // ToolName: "start_timer", + // }, + } + return mappings[actionType] +} + +// getXTDriverFromCache gets XTDriver from cache using device UUID +func getXTDriverFromCache(driver IDriver) *XTDriver { + // Get device info to find the corresponding XTDriver + device := driver.GetDevice() + if device == nil { + log.Warn().Msg("Cannot get device from driver for MCP hook") + return nil + } + + // Get device UUID (serial/udid/connectKey/browserID) + deviceUUID := device.UUID() + if deviceUUID == "" { + log.Warn().Msg("Cannot get device UUID for MCP hook") + return nil + } + + // Get XTDriver from cache using device UUID as serial + cachedDrivers := ListCachedDrivers() + for _, cached := range cachedDrivers { + if cached.Serial == deviceUUID { + return cached.Driver + } + } + + log.Warn().Str("uuid", deviceUUID). + Msg("Cannot find cached XTDriver for MCP hook") + return nil +} + +// callMCPTool calls the specified MCP tool +func callMCPTool(dExt *XTDriver, serverName, toolName string, arguments map[string]any, actionType option.ActionName) { + // Get MCP client + mcpClient, exists := dExt.GetMCPClient(serverName) + if !exists { + log.Debug().Str("server", serverName).Msg("MCP server not found for hook") + return + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Prepare arguments + if arguments == nil { + arguments = make(map[string]any) + } + // Add action type and hook type to arguments + arguments["action_type"] = string(actionType) + + // Call MCP tool + req := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Name: toolName, + Arguments: arguments, + }, + } + + result, err := mcpClient.CallTool(ctx, req) + if err != nil { + log.Debug().Err(err). + Str("server", serverName). + Str("tool", toolName). + Msg("MCP hook call failed") + return + } + + if result.IsError { + log.Debug(). + Str("server", serverName). + Str("tool", toolName). + Interface("content", result.Content). + Msg("MCP hook returned error") + return + } + + log.Debug(). + Str("server", serverName). + Str("tool", toolName). + Str("action", string(actionType)). + Msg("MCP hook called successfully") +} diff --git a/uixt/mcp_server.go b/uixt/mcp_server.go index 4a4d96ff..3929f9a4 100644 --- a/uixt/mcp_server.go +++ b/uixt/mcp_server.go @@ -1071,10 +1071,10 @@ func (t *ToolSwipeCoordinate) Implement() server.ToolHandlerFunc { params := []float64{unifiedReq.FromX, unifiedReq.FromY, unifiedReq.ToX, unifiedReq.ToY} opts := []option.ActionOption{} - if unifiedReq.Duration > 0 && unifiedReq.Duration > 0 { + if unifiedReq.Duration > 0 { opts = append(opts, option.WithDuration(unifiedReq.Duration)) } - if unifiedReq.PressDuration > 0 && unifiedReq.PressDuration > 0 { + if unifiedReq.PressDuration > 0 { opts = append(opts, option.WithPressDuration(unifiedReq.PressDuration)) } @@ -1146,10 +1146,10 @@ func (t *ToolSwipeToTapApp) Implement() server.ToolHandlerFunc { } // Add numeric options - if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 { + if unifiedReq.MaxRetryTimes > 0 { opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes)) } - if unifiedReq.Index > 0 && unifiedReq.Index > 0 { + if unifiedReq.Index > 0 { opts = append(opts, option.WithIndex(unifiedReq.Index)) } @@ -1218,10 +1218,10 @@ func (t *ToolSwipeToTapText) Implement() server.ToolHandlerFunc { } // Add numeric options - if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 { + if unifiedReq.MaxRetryTimes > 0 { opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes)) } - if unifiedReq.Index > 0 && unifiedReq.Index > 0 { + if unifiedReq.Index > 0 { opts = append(opts, option.WithIndex(unifiedReq.Index)) } @@ -1290,10 +1290,10 @@ func (t *ToolSwipeToTapTexts) Implement() server.ToolHandlerFunc { } // Add numeric options - if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 { + if unifiedReq.MaxRetryTimes > 0 { opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes)) } - if unifiedReq.Index > 0 && unifiedReq.Index > 0 { + if unifiedReq.Index > 0 { opts = append(opts, option.WithIndex(unifiedReq.Index)) } diff --git a/uixt/option/action.go b/uixt/option/action.go index 108132c3..f145261a 100644 --- a/uixt/option/action.go +++ b/uixt/option/action.go @@ -175,6 +175,9 @@ type ActionOptions struct { ScreenOptions + // Anti-risk options + AntiRisk bool `json:"anti_risk,omitempty" yaml:"anti_risk,omitempty" desc:"Enable anti-risk MCP tool calls"` + // Custom options Custom map[string]interface{} `json:"custom,omitempty" yaml:"custom,omitempty" desc:"Custom options"` } @@ -286,6 +289,10 @@ func (o *ActionOptions) Options() []ActionOption { options = append(options, WithMatchOne(true)) } + if o.AntiRisk { + options = append(options, WithAntiRisk(true)) + } + // custom options if o.Custom != nil { for k, v := range o.Custom { @@ -494,6 +501,12 @@ func WithIgnoreNotFoundError(ignoreError bool) ActionOption { } } +func WithAntiRisk(antiRisk bool) ActionOption { + return func(o *ActionOptions) { + o.AntiRisk = antiRisk + } +} + // HTTP API direct usage methods // ValidateForHTTPAPI validates the request for HTTP API usage diff --git a/uixt/sdk.go b/uixt/sdk.go index d55c26d5..4ce4b05d 100644 --- a/uixt/sdk.go +++ b/uixt/sdk.go @@ -15,9 +15,10 @@ import ( func NewXTDriver(driver IDriver, opts ...option.AIServiceOption) (*XTDriver, error) { driverExt := &XTDriver{ IDriver: driver, - Client: &MCPClient4XTDriver{ + client: &MCPClient4XTDriver{ Server: NewMCPServer(), }, + loadedMCPClients: make(map[string]client.MCPClient), } services := option.NewAIServiceOptions(opts...) @@ -47,7 +48,8 @@ type XTDriver struct { CVService ai.ICVService // OCR/CV LLMService ai.ILLMService // LLM - Client *MCPClient4XTDriver // MCP Client + client *MCPClient4XTDriver // MCP Client for built-in uixt server + loadedMCPClients map[string]client.MCPClient // External MCP clients } // MCPClient4XTDriver is a mock MCP client that only implements the methods used by the host @@ -80,9 +82,9 @@ func (c *MCPClient4XTDriver) Close() error { return nil } -func (dExt *XTDriver) ExecuteAction(action MobileAction) (err error) { +func (dExt *XTDriver) ExecuteAction(ctx context.Context, action MobileAction) (err error) { // Find the corresponding tool for this action method - tool := dExt.Client.Server.GetToolByAction(action.Method) + tool := dExt.client.Server.GetToolByAction(action.Method) if tool == nil { return fmt.Errorf("no tool found for action method: %s", action.Method) } @@ -94,7 +96,7 @@ func (dExt *XTDriver) ExecuteAction(action MobileAction) (err error) { } // Execute via MCP tool - result, err := dExt.Client.CallTool(context.Background(), req) + result, err := dExt.client.CallTool(ctx, req) if err != nil { return fmt.Errorf("MCP tool call failed: %w", err) } @@ -139,3 +141,22 @@ func NewDeviceWithDefault(platform, serial string) (device IDevice, err error) { return device, err } + +// SetMCPClients sets the external MCP clients for the driver +func (dExt *XTDriver) SetMCPClients(clients map[string]client.MCPClient) { + if dExt.loadedMCPClients == nil { + dExt.loadedMCPClients = make(map[string]client.MCPClient) + } + for name, client := range clients { + dExt.loadedMCPClients[name] = client + } +} + +// GetMCPClient returns the MCP client for the specified server name +func (dExt *XTDriver) GetMCPClient(serverName string) (client.MCPClient, bool) { + if dExt.loadedMCPClients == nil { + return nil, false + } + client, exists := dExt.loadedMCPClients[serverName] + return client, exists +}