feat: implement MCP hooks integration with anti_risk option

This commit is contained in:
lilong.129
2025-05-27 19:46:08 +08:00
parent f4cc74b3ca
commit 866cc0e4d2
10 changed files with 222 additions and 38 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2505271534
v5.0.0-beta-2505271946

View File

@@ -26,7 +26,6 @@ type MCPHost struct {
connections map[string]*Connection
config *MCPConfig
withUIXT bool
drivers map[string]*uixt.XTDriver
}
// Connection represents a connection to an MCP server
@@ -52,7 +51,6 @@ func NewMCPHost(configPath string, withUIXT bool) (*MCPHost, error) {
host := &MCPHost{
connections: make(map[string]*Connection),
config: config,
drivers: make(map[string]*uixt.XTDriver),
withUIXT: withUIXT,
}
@@ -175,6 +173,18 @@ func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
return conn.Client, nil
}
// GetAllClients returns the MCP client of every connection held by the host,
// keyed by the same name the connection is stored under.
//
// The returned map is freshly allocated on each call and is built while
// holding the host's read lock.
func (h *MCPHost) GetAllClients() map[string]client.MCPClient {
	h.mu.RLock()
	defer h.mu.RUnlock()

	snapshot := make(map[string]client.MCPClient)
	for serverName := range h.connections {
		snapshot[serverName] = h.connections[serverName].Client
	}
	return snapshot
}
// GetTools returns all tools from all MCP servers
func (h *MCPHost) GetTools(ctx context.Context) []MCPTools {
h.mu.RLock()
@@ -204,28 +214,20 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc
h.mu.RLock()
defer h.mu.RUnlock()
// Get all tools
results := h.GetTools(ctx)
// Find the server's tools
var serverTools MCPTools
found := false
for _, tools := range results {
if tools.ServerName == serverName {
serverTools = tools
found = true
break
}
}
if !found {
// Get connection for the server
conn, exists := h.connections[serverName]
if !exists {
return nil, fmt.Errorf("no connection found for MCP server %s", serverName)
}
if serverTools.Err != nil {
return nil, serverTools.Err
// Get tools from the specific server
listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
return nil, fmt.Errorf("failed to get tools from server %s: %w", serverName, err)
}
// Find the specific tool
for _, tool := range serverTools.Tools {
for _, tool := range listResults.Tools {
if tool.Name == toolName {
return &tool, nil
}

View File

@@ -495,10 +495,19 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) {
// init XTDriver and register to unified cache
for _, driverConfig := range driverConfigs {
_, err := uixt.GetOrCreateXTDriver(driverConfig)
driver, err := uixt.GetOrCreateXTDriver(driverConfig)
if err != nil {
return nil, errors.Wrapf(err, "init %s XTDriver failed", driverConfig.Platform)
}
// Set MCP clients if MCPHost is available
if r.parser.MCPHost != nil {
mcpClients := r.parser.MCPHost.GetAllClients()
driver.SetMCPClients(mcpClients)
log.Debug().Str("serial", driverConfig.Serial).
Int("mcp_clients", len(mcpClients)).
Msg("Set MCP clients for XTDriver")
}
}
return parsedConfig, nil

View File

@@ -19,7 +19,7 @@ func (r *Router) uixtActionHandler(c *gin.Context) {
return
}
if err = dExt.ExecuteAction(req); err != nil {
if err = dExt.ExecuteAction(c.Request.Context(), req); err != nil {
log.Err(err).Interface("action", req).
Msg("exec uixt action failed")
RenderError(c, err)
@@ -42,7 +42,7 @@ func (r *Router) uixtActionsHandler(c *gin.Context) {
}
for _, action := range actions {
if err = dExt.ExecuteAction(action); err != nil {
if err = dExt.ExecuteAction(c.Request.Context(), action); err != nil {
log.Err(err).Interface("action", action).
Msg("exec uixt action failed")
RenderError(c, err)

View File

@@ -1,6 +1,7 @@
package hrp
import (
"context"
"fmt"
"strings"
"time"
@@ -803,7 +804,7 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
continue
}
err = uiDriver.ExecuteAction(action)
err = uiDriver.ExecuteAction(context.Background(), action)
actionResult.Elapsed = time.Since(actionStartTime).Milliseconds()
stepResult.Actions = append(stepResult.Actions, actionResult)
if err != nil {

View File

@@ -62,7 +62,7 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
},
}
_, err = dExt.Client.CallTool(context.Background(), req)
_, err = dExt.client.CallTool(context.Background(), req)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package uixt
import (
"context"
"fmt"
"path/filepath"
"time"
@@ -8,6 +9,7 @@ import (
"github.com/httprunner/httprunner/v5/internal/builtin"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/mark3labs/mcp-go/mcp"
"github.com/rs/zerolog/log"
)
@@ -47,6 +49,14 @@ func (dExt *XTDriver) Call(desc string, fn func(), opts ...option.ActionOption)
func preHandler_TapAbsXY(driver IDriver, options *option.ActionOptions, rawX, rawY float64) (
x, y float64, err error) {
// Call MCP action tool if anti-risk is enabled
if options.AntiRisk {
callMCPActionTool(driver, option.ACTION_TapAbsXY, map[string]any{
"x": rawX,
"y": rawY,
})
}
x, y = options.ApplyTapOffset(rawX, rawY)
// mark UI operation
@@ -143,3 +153,131 @@ func postHandler(driver IDriver, actionType option.ActionName, options *option.A
}
return nil
}
// callMCPActionTool runs the pre-hook MCP tool configured for actionType.
//
// It does nothing when no cached XTDriver can be resolved for driver, or when
// getPreHookServerMapping has no entry for actionType. Nothing is returned to
// the caller in either case.
func callMCPActionTool(driver IDriver, actionType option.ActionName, arguments map[string]any) {
	// The cache lookup comes first so that its warnings are emitted even when
	// the action has no hook configured.
	dExt := getXTDriverFromCache(driver)
	if dExt == nil {
		return
	}

	// A nil mapping means no MCP hook is configured for this action.
	if hook := getPreHookServerMapping(actionType); hook != nil {
		callMCPTool(dExt, hook.ServerName, hook.ToolName, arguments, actionType)
	}
}
// MCPServerMapping names the MCP server and tool that a hook should call for
// a given action.
type MCPServerMapping struct {
	// ServerName is the name used to look up the external MCP client.
	ServerName string
	// ToolName is the tool name sent in the MCP call request.
	ToolName string
}
// getPreHookServerMapping returns the MCP server/tool pair to call before
// actionType is executed, or nil when the action has no pre-hook.
//
// Each call returns a newly allocated mapping.
// TODO: customize these mappings according to your needs.
func getPreHookServerMapping(actionType option.ActionName) *MCPServerMapping {
	switch actionType {
	case option.ACTION_TapAbsXY:
		return &MCPServerMapping{
			ServerName: "evalpkgs",
			ToolName:   "log_pre_action",
		}
		// Add more mappings as needed, for example:
		// case option.ACTION_Swipe:
		// 	return &MCPServerMapping{
		// 		ServerName: "monitor",
		// 		ToolName:   "start_timer",
		// 	}
	}
	return nil
}
// getXTDriverFromCache looks up the cached XTDriver whose Serial equals the
// UUID of the device behind driver.
//
// It returns nil, after logging a warning, when the driver has no device, the
// device reports an empty UUID, or no cached entry matches.
func getXTDriverFromCache(driver IDriver) *XTDriver {
	dev := driver.GetDevice()
	if dev == nil {
		log.Warn().Msg("Cannot get device from driver for MCP hook")
		return nil
	}

	// The UUID is compared against the cache entry's Serial below.
	uuid := dev.UUID()
	if uuid == "" {
		log.Warn().Msg("Cannot get device UUID for MCP hook")
		return nil
	}

	for _, entry := range ListCachedDrivers() {
		if entry.Serial != uuid {
			continue
		}
		return entry.Driver
	}

	log.Warn().Str("uuid", uuid).
		Msg("Cannot find cached XTDriver for MCP hook")
	return nil
}
// callMCPTool calls toolName on the external MCP server registered under
// serverName, as a best-effort hook for actionType.
//
// The tool receives a copy of arguments plus an "action_type" entry holding
// the action name; the caller's map is never modified. The call is bounded by
// a 30-second timeout.
//
// Failures are deliberately not propagated: a missing server, a call error, an
// empty result, or a tool-level error is only logged at debug level so that a
// broken hook cannot fail the UI action that triggered it.
func callMCPTool(dExt *XTDriver, serverName, toolName string, arguments map[string]any, actionType option.ActionName) {
	// Resolve the external MCP client; a nil client registered under the name
	// is treated the same as a missing one.
	mcpClient, exists := dExt.GetMCPClient(serverName)
	if !exists || mcpClient == nil {
		log.Debug().Str("server", serverName).Msg("MCP server not found for hook")
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Build the payload on a private copy: maps are reference types, so
	// writing "action_type" into arguments directly would leak into the
	// caller's map.
	payload := make(map[string]any, len(arguments)+1)
	for key, value := range arguments {
		payload[key] = value
	}
	payload["action_type"] = string(actionType)

	req := mcp.CallToolRequest{
		Params: struct {
			Name      string         `json:"name"`
			Arguments map[string]any `json:"arguments,omitempty"`
			Meta      *struct {
				ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
			} `json:"_meta,omitempty"`
		}{
			Name:      toolName,
			Arguments: payload,
		},
	}

	result, err := mcpClient.CallTool(ctx, req)
	if err != nil {
		log.Debug().Err(err).
			Str("server", serverName).
			Str("tool", toolName).
			Msg("MCP hook call failed")
		return
	}
	// Guard against a client that returns neither a result nor an error.
	if result == nil {
		log.Debug().
			Str("server", serverName).
			Str("tool", toolName).
			Msg("MCP hook returned no result")
		return
	}
	if result.IsError {
		log.Debug().
			Str("server", serverName).
			Str("tool", toolName).
			Interface("content", result.Content).
			Msg("MCP hook returned error")
		return
	}

	log.Debug().
		Str("server", serverName).
		Str("tool", toolName).
		Str("action", string(actionType)).
		Msg("MCP hook called successfully")
}

View File

@@ -1071,10 +1071,10 @@ func (t *ToolSwipeCoordinate) Implement() server.ToolHandlerFunc {
params := []float64{unifiedReq.FromX, unifiedReq.FromY, unifiedReq.ToX, unifiedReq.ToY}
opts := []option.ActionOption{}
if unifiedReq.Duration > 0 && unifiedReq.Duration > 0 {
if unifiedReq.Duration > 0 {
opts = append(opts, option.WithDuration(unifiedReq.Duration))
}
if unifiedReq.PressDuration > 0 && unifiedReq.PressDuration > 0 {
if unifiedReq.PressDuration > 0 {
opts = append(opts, option.WithPressDuration(unifiedReq.PressDuration))
}
@@ -1146,10 +1146,10 @@ func (t *ToolSwipeToTapApp) Implement() server.ToolHandlerFunc {
}
// Add numeric options
if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 {
if unifiedReq.MaxRetryTimes > 0 {
opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes))
}
if unifiedReq.Index > 0 && unifiedReq.Index > 0 {
if unifiedReq.Index > 0 {
opts = append(opts, option.WithIndex(unifiedReq.Index))
}
@@ -1218,10 +1218,10 @@ func (t *ToolSwipeToTapText) Implement() server.ToolHandlerFunc {
}
// Add numeric options
if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 {
if unifiedReq.MaxRetryTimes > 0 {
opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes))
}
if unifiedReq.Index > 0 && unifiedReq.Index > 0 {
if unifiedReq.Index > 0 {
opts = append(opts, option.WithIndex(unifiedReq.Index))
}
@@ -1290,10 +1290,10 @@ func (t *ToolSwipeToTapTexts) Implement() server.ToolHandlerFunc {
}
// Add numeric options
if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 {
if unifiedReq.MaxRetryTimes > 0 {
opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes))
}
if unifiedReq.Index > 0 && unifiedReq.Index > 0 {
if unifiedReq.Index > 0 {
opts = append(opts, option.WithIndex(unifiedReq.Index))
}

View File

@@ -175,6 +175,9 @@ type ActionOptions struct {
ScreenOptions
// Anti-risk options
AntiRisk bool `json:"anti_risk,omitempty" yaml:"anti_risk,omitempty" desc:"Enable anti-risk MCP tool calls"`
// Custom options
Custom map[string]interface{} `json:"custom,omitempty" yaml:"custom,omitempty" desc:"Custom options"`
}
@@ -286,6 +289,10 @@ func (o *ActionOptions) Options() []ActionOption {
options = append(options, WithMatchOne(true))
}
if o.AntiRisk {
options = append(options, WithAntiRisk(true))
}
// custom options
if o.Custom != nil {
for k, v := range o.Custom {
@@ -494,6 +501,12 @@ func WithIgnoreNotFoundError(ignoreError bool) ActionOption {
}
}
// WithAntiRisk returns an ActionOption that sets the AntiRisk field of the
// ActionOptions it is applied to.
func WithAntiRisk(antiRisk bool) ActionOption {
	return func(opts *ActionOptions) {
		opts.AntiRisk = antiRisk
	}
}
// HTTP API direct usage methods
// ValidateForHTTPAPI validates the request for HTTP API usage

View File

@@ -15,9 +15,10 @@ import (
func NewXTDriver(driver IDriver, opts ...option.AIServiceOption) (*XTDriver, error) {
driverExt := &XTDriver{
IDriver: driver,
Client: &MCPClient4XTDriver{
client: &MCPClient4XTDriver{
Server: NewMCPServer(),
},
loadedMCPClients: make(map[string]client.MCPClient),
}
services := option.NewAIServiceOptions(opts...)
@@ -47,7 +48,8 @@ type XTDriver struct {
CVService ai.ICVService // OCR/CV
LLMService ai.ILLMService // LLM
Client *MCPClient4XTDriver // MCP Client
client *MCPClient4XTDriver // MCP Client for built-in uixt server
loadedMCPClients map[string]client.MCPClient // External MCP clients
}
// MCPClient4XTDriver is a mock MCP client that only implements the methods used by the host
@@ -80,9 +82,9 @@ func (c *MCPClient4XTDriver) Close() error {
return nil
}
func (dExt *XTDriver) ExecuteAction(action MobileAction) (err error) {
func (dExt *XTDriver) ExecuteAction(ctx context.Context, action MobileAction) (err error) {
// Find the corresponding tool for this action method
tool := dExt.Client.Server.GetToolByAction(action.Method)
tool := dExt.client.Server.GetToolByAction(action.Method)
if tool == nil {
return fmt.Errorf("no tool found for action method: %s", action.Method)
}
@@ -94,7 +96,7 @@ func (dExt *XTDriver) ExecuteAction(action MobileAction) (err error) {
}
// Execute via MCP tool
result, err := dExt.Client.CallTool(context.Background(), req)
result, err := dExt.client.CallTool(ctx, req)
if err != nil {
return fmt.Errorf("MCP tool call failed: %w", err)
}
@@ -139,3 +141,22 @@ func NewDeviceWithDefault(platform, serial string) (device IDevice, err error) {
return device, err
}
// SetMCPClients merges clients into the driver's set of external MCP clients.
//
// Existing entries are kept unless clients contains the same name, in which
// case the new client replaces the old one. The internal map is created on
// first use if it is nil.
func (dExt *XTDriver) SetMCPClients(clients map[string]client.MCPClient) {
	if dExt.loadedMCPClients == nil {
		dExt.loadedMCPClients = make(map[string]client.MCPClient)
	}
	for serverName, mcpClient := range clients {
		dExt.loadedMCPClients[serverName] = mcpClient
	}
}
// GetMCPClient looks up the external MCP client stored under serverName.
//
// The boolean result reports whether an entry exists. It is false when the
// name is unknown or when no clients have been set at all; indexing a nil map
// already yields the zero value and false, so no separate nil check is needed.
func (dExt *XTDriver) GetMCPClient(serverName string) (client.MCPClient, bool) {
	mcpClient, found := dExt.loadedMCPClients[serverName]
	return mcpClient, found
}