mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-06 20:32:44 +08:00
feat: implement MCP hooks integration with anti_risk option
This commit is contained in:
@@ -1 +1 @@
|
||||
v5.0.0-beta-2505271534
|
||||
v5.0.0-beta-2505271946
|
||||
|
||||
@@ -26,7 +26,6 @@ type MCPHost struct {
|
||||
connections map[string]*Connection
|
||||
config *MCPConfig
|
||||
withUIXT bool
|
||||
drivers map[string]*uixt.XTDriver
|
||||
}
|
||||
|
||||
// Connection represents a connection to an MCP server
|
||||
@@ -52,7 +51,6 @@ func NewMCPHost(configPath string, withUIXT bool) (*MCPHost, error) {
|
||||
host := &MCPHost{
|
||||
connections: make(map[string]*Connection),
|
||||
config: config,
|
||||
drivers: make(map[string]*uixt.XTDriver),
|
||||
withUIXT: withUIXT,
|
||||
}
|
||||
|
||||
@@ -175,6 +173,18 @@ func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
|
||||
return conn.Client, nil
|
||||
}
|
||||
|
||||
// GetAllClients returns all MCP clients
|
||||
func (h *MCPHost) GetAllClients() map[string]client.MCPClient {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
clients := make(map[string]client.MCPClient)
|
||||
for name, conn := range h.connections {
|
||||
clients[name] = conn.Client
|
||||
}
|
||||
return clients
|
||||
}
|
||||
|
||||
// GetTools returns all tools from all MCP servers
|
||||
func (h *MCPHost) GetTools(ctx context.Context) []MCPTools {
|
||||
h.mu.RLock()
|
||||
@@ -204,28 +214,20 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// Get all tools
|
||||
results := h.GetTools(ctx)
|
||||
|
||||
// Find the server's tools
|
||||
var serverTools MCPTools
|
||||
found := false
|
||||
for _, tools := range results {
|
||||
if tools.ServerName == serverName {
|
||||
serverTools = tools
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// Get connection for the server
|
||||
conn, exists := h.connections[serverName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no connection found for MCP server %s", serverName)
|
||||
}
|
||||
if serverTools.Err != nil {
|
||||
return nil, serverTools.Err
|
||||
|
||||
// Get tools from the specific server
|
||||
listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get tools from server %s: %w", serverName, err)
|
||||
}
|
||||
|
||||
// Find the specific tool
|
||||
for _, tool := range serverTools.Tools {
|
||||
for _, tool := range listResults.Tools {
|
||||
if tool.Name == toolName {
|
||||
return &tool, nil
|
||||
}
|
||||
|
||||
11
runner.go
11
runner.go
@@ -495,10 +495,19 @@ func (r *CaseRunner) parseConfig() (parsedConfig *TConfig, err error) {
|
||||
|
||||
// init XTDriver and register to unified cache
|
||||
for _, driverConfig := range driverConfigs {
|
||||
_, err := uixt.GetOrCreateXTDriver(driverConfig)
|
||||
driver, err := uixt.GetOrCreateXTDriver(driverConfig)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "init %s XTDriver failed", driverConfig.Platform)
|
||||
}
|
||||
|
||||
// Set MCP clients if MCPHost is available
|
||||
if r.parser.MCPHost != nil {
|
||||
mcpClients := r.parser.MCPHost.GetAllClients()
|
||||
driver.SetMCPClients(mcpClients)
|
||||
log.Debug().Str("serial", driverConfig.Serial).
|
||||
Int("mcp_clients", len(mcpClients)).
|
||||
Msg("Set MCP clients for XTDriver")
|
||||
}
|
||||
}
|
||||
|
||||
return parsedConfig, nil
|
||||
|
||||
@@ -19,7 +19,7 @@ func (r *Router) uixtActionHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = dExt.ExecuteAction(req); err != nil {
|
||||
if err = dExt.ExecuteAction(c.Request.Context(), req); err != nil {
|
||||
log.Err(err).Interface("action", req).
|
||||
Msg("exec uixt action failed")
|
||||
RenderError(c, err)
|
||||
@@ -42,7 +42,7 @@ func (r *Router) uixtActionsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
if err = dExt.ExecuteAction(action); err != nil {
|
||||
if err = dExt.ExecuteAction(c.Request.Context(), action); err != nil {
|
||||
log.Err(err).Interface("action", action).
|
||||
Msg("exec uixt action failed")
|
||||
RenderError(c, err)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package hrp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -803,7 +804,7 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
|
||||
continue
|
||||
}
|
||||
|
||||
err = uiDriver.ExecuteAction(action)
|
||||
err = uiDriver.ExecuteAction(context.Background(), action)
|
||||
actionResult.Elapsed = time.Since(actionStartTime).Milliseconds()
|
||||
stepResult.Actions = append(stepResult.Actions, actionResult)
|
||||
if err != nil {
|
||||
|
||||
@@ -62,7 +62,7 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
|
||||
},
|
||||
}
|
||||
|
||||
_, err = dExt.Client.CallTool(context.Background(), req)
|
||||
_, err = dExt.client.CallTool(context.Background(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uixt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"github.com/httprunner/httprunner/v5/internal/builtin"
|
||||
"github.com/httprunner/httprunner/v5/internal/config"
|
||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -47,6 +49,14 @@ func (dExt *XTDriver) Call(desc string, fn func(), opts ...option.ActionOption)
|
||||
func preHandler_TapAbsXY(driver IDriver, options *option.ActionOptions, rawX, rawY float64) (
|
||||
x, y float64, err error) {
|
||||
|
||||
// Call MCP action tool if anti-risk is enabled
|
||||
if options.AntiRisk {
|
||||
callMCPActionTool(driver, option.ACTION_TapAbsXY, map[string]any{
|
||||
"x": rawX,
|
||||
"y": rawY,
|
||||
})
|
||||
}
|
||||
|
||||
x, y = options.ApplyTapOffset(rawX, rawY)
|
||||
|
||||
// mark UI operation
|
||||
@@ -143,3 +153,131 @@ func postHandler(driver IDriver, actionType option.ActionName, options *option.A
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// callMCPActionTool calls MCP tool for the given action
|
||||
func callMCPActionTool(driver IDriver, actionType option.ActionName, arguments map[string]any) {
|
||||
// Get XTDriver from cache
|
||||
dExt := getXTDriverFromCache(driver)
|
||||
if dExt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Define action to MCP server mapping for pre-hooks
|
||||
serverMapping := getPreHookServerMapping(actionType)
|
||||
if serverMapping == nil {
|
||||
return // No MCP hook configured for this action
|
||||
}
|
||||
|
||||
callMCPTool(dExt, serverMapping.ServerName, serverMapping.ToolName, arguments, actionType)
|
||||
}
|
||||
|
||||
// MCPServerMapping defines the mapping between action and MCP server/tool
|
||||
type MCPServerMapping struct {
|
||||
ServerName string
|
||||
ToolName string
|
||||
}
|
||||
|
||||
// getPreHookServerMapping returns MCP server mapping for pre-hooks
|
||||
// TODO: You can customize these mappings according to your needs
|
||||
func getPreHookServerMapping(actionType option.ActionName) *MCPServerMapping {
|
||||
mappings := map[option.ActionName]*MCPServerMapping{
|
||||
option.ACTION_TapAbsXY: {
|
||||
ServerName: "evalpkgs",
|
||||
ToolName: "log_pre_action",
|
||||
},
|
||||
// Add more mappings as needed
|
||||
// option.ACTION_Swipe: {
|
||||
// ServerName: "monitor",
|
||||
// ToolName: "start_timer",
|
||||
// },
|
||||
}
|
||||
return mappings[actionType]
|
||||
}
|
||||
|
||||
// getXTDriverFromCache gets XTDriver from cache using device UUID
|
||||
func getXTDriverFromCache(driver IDriver) *XTDriver {
|
||||
// Get device info to find the corresponding XTDriver
|
||||
device := driver.GetDevice()
|
||||
if device == nil {
|
||||
log.Warn().Msg("Cannot get device from driver for MCP hook")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get device UUID (serial/udid/connectKey/browserID)
|
||||
deviceUUID := device.UUID()
|
||||
if deviceUUID == "" {
|
||||
log.Warn().Msg("Cannot get device UUID for MCP hook")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get XTDriver from cache using device UUID as serial
|
||||
cachedDrivers := ListCachedDrivers()
|
||||
for _, cached := range cachedDrivers {
|
||||
if cached.Serial == deviceUUID {
|
||||
return cached.Driver
|
||||
}
|
||||
}
|
||||
|
||||
log.Warn().Str("uuid", deviceUUID).
|
||||
Msg("Cannot find cached XTDriver for MCP hook")
|
||||
return nil
|
||||
}
|
||||
|
||||
// callMCPTool calls the specified MCP tool
|
||||
func callMCPTool(dExt *XTDriver, serverName, toolName string, arguments map[string]any, actionType option.ActionName) {
|
||||
// Get MCP client
|
||||
mcpClient, exists := dExt.GetMCPClient(serverName)
|
||||
if !exists {
|
||||
log.Debug().Str("server", serverName).Msg("MCP server not found for hook")
|
||||
return
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Prepare arguments
|
||||
if arguments == nil {
|
||||
arguments = make(map[string]any)
|
||||
}
|
||||
// Add action type and hook type to arguments
|
||||
arguments["action_type"] = string(actionType)
|
||||
|
||||
// Call MCP tool
|
||||
req := mcp.CallToolRequest{
|
||||
Params: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Meta *struct {
|
||||
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
}{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := mcpClient.CallTool(ctx, req)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).
|
||||
Str("server", serverName).
|
||||
Str("tool", toolName).
|
||||
Msg("MCP hook call failed")
|
||||
return
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
log.Debug().
|
||||
Str("server", serverName).
|
||||
Str("tool", toolName).
|
||||
Interface("content", result.Content).
|
||||
Msg("MCP hook returned error")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("server", serverName).
|
||||
Str("tool", toolName).
|
||||
Str("action", string(actionType)).
|
||||
Msg("MCP hook called successfully")
|
||||
}
|
||||
|
||||
@@ -1071,10 +1071,10 @@ func (t *ToolSwipeCoordinate) Implement() server.ToolHandlerFunc {
|
||||
|
||||
params := []float64{unifiedReq.FromX, unifiedReq.FromY, unifiedReq.ToX, unifiedReq.ToY}
|
||||
opts := []option.ActionOption{}
|
||||
if unifiedReq.Duration > 0 && unifiedReq.Duration > 0 {
|
||||
if unifiedReq.Duration > 0 {
|
||||
opts = append(opts, option.WithDuration(unifiedReq.Duration))
|
||||
}
|
||||
if unifiedReq.PressDuration > 0 && unifiedReq.PressDuration > 0 {
|
||||
if unifiedReq.PressDuration > 0 {
|
||||
opts = append(opts, option.WithPressDuration(unifiedReq.PressDuration))
|
||||
}
|
||||
|
||||
@@ -1146,10 +1146,10 @@ func (t *ToolSwipeToTapApp) Implement() server.ToolHandlerFunc {
|
||||
}
|
||||
|
||||
// Add numeric options
|
||||
if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 {
|
||||
if unifiedReq.MaxRetryTimes > 0 {
|
||||
opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes))
|
||||
}
|
||||
if unifiedReq.Index > 0 && unifiedReq.Index > 0 {
|
||||
if unifiedReq.Index > 0 {
|
||||
opts = append(opts, option.WithIndex(unifiedReq.Index))
|
||||
}
|
||||
|
||||
@@ -1218,10 +1218,10 @@ func (t *ToolSwipeToTapText) Implement() server.ToolHandlerFunc {
|
||||
}
|
||||
|
||||
// Add numeric options
|
||||
if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 {
|
||||
if unifiedReq.MaxRetryTimes > 0 {
|
||||
opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes))
|
||||
}
|
||||
if unifiedReq.Index > 0 && unifiedReq.Index > 0 {
|
||||
if unifiedReq.Index > 0 {
|
||||
opts = append(opts, option.WithIndex(unifiedReq.Index))
|
||||
}
|
||||
|
||||
@@ -1290,10 +1290,10 @@ func (t *ToolSwipeToTapTexts) Implement() server.ToolHandlerFunc {
|
||||
}
|
||||
|
||||
// Add numeric options
|
||||
if unifiedReq.MaxRetryTimes > 0 && unifiedReq.MaxRetryTimes > 0 {
|
||||
if unifiedReq.MaxRetryTimes > 0 {
|
||||
opts = append(opts, option.WithMaxRetryTimes(unifiedReq.MaxRetryTimes))
|
||||
}
|
||||
if unifiedReq.Index > 0 && unifiedReq.Index > 0 {
|
||||
if unifiedReq.Index > 0 {
|
||||
opts = append(opts, option.WithIndex(unifiedReq.Index))
|
||||
}
|
||||
|
||||
|
||||
@@ -175,6 +175,9 @@ type ActionOptions struct {
|
||||
|
||||
ScreenOptions
|
||||
|
||||
// Anti-risk options
|
||||
AntiRisk bool `json:"anti_risk,omitempty" yaml:"anti_risk,omitempty" desc:"Enable anti-risk MCP tool calls"`
|
||||
|
||||
// Custom options
|
||||
Custom map[string]interface{} `json:"custom,omitempty" yaml:"custom,omitempty" desc:"Custom options"`
|
||||
}
|
||||
@@ -286,6 +289,10 @@ func (o *ActionOptions) Options() []ActionOption {
|
||||
options = append(options, WithMatchOne(true))
|
||||
}
|
||||
|
||||
if o.AntiRisk {
|
||||
options = append(options, WithAntiRisk(true))
|
||||
}
|
||||
|
||||
// custom options
|
||||
if o.Custom != nil {
|
||||
for k, v := range o.Custom {
|
||||
@@ -494,6 +501,12 @@ func WithIgnoreNotFoundError(ignoreError bool) ActionOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithAntiRisk(antiRisk bool) ActionOption {
|
||||
return func(o *ActionOptions) {
|
||||
o.AntiRisk = antiRisk
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP API direct usage methods
|
||||
|
||||
// ValidateForHTTPAPI validates the request for HTTP API usage
|
||||
|
||||
31
uixt/sdk.go
31
uixt/sdk.go
@@ -15,9 +15,10 @@ import (
|
||||
func NewXTDriver(driver IDriver, opts ...option.AIServiceOption) (*XTDriver, error) {
|
||||
driverExt := &XTDriver{
|
||||
IDriver: driver,
|
||||
Client: &MCPClient4XTDriver{
|
||||
client: &MCPClient4XTDriver{
|
||||
Server: NewMCPServer(),
|
||||
},
|
||||
loadedMCPClients: make(map[string]client.MCPClient),
|
||||
}
|
||||
|
||||
services := option.NewAIServiceOptions(opts...)
|
||||
@@ -47,7 +48,8 @@ type XTDriver struct {
|
||||
CVService ai.ICVService // OCR/CV
|
||||
LLMService ai.ILLMService // LLM
|
||||
|
||||
Client *MCPClient4XTDriver // MCP Client
|
||||
client *MCPClient4XTDriver // MCP Client for built-in uixt server
|
||||
loadedMCPClients map[string]client.MCPClient // External MCP clients
|
||||
}
|
||||
|
||||
// MCPClient4XTDriver is a mock MCP client that only implements the methods used by the host
|
||||
@@ -80,9 +82,9 @@ func (c *MCPClient4XTDriver) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dExt *XTDriver) ExecuteAction(action MobileAction) (err error) {
|
||||
func (dExt *XTDriver) ExecuteAction(ctx context.Context, action MobileAction) (err error) {
|
||||
// Find the corresponding tool for this action method
|
||||
tool := dExt.Client.Server.GetToolByAction(action.Method)
|
||||
tool := dExt.client.Server.GetToolByAction(action.Method)
|
||||
if tool == nil {
|
||||
return fmt.Errorf("no tool found for action method: %s", action.Method)
|
||||
}
|
||||
@@ -94,7 +96,7 @@ func (dExt *XTDriver) ExecuteAction(action MobileAction) (err error) {
|
||||
}
|
||||
|
||||
// Execute via MCP tool
|
||||
result, err := dExt.Client.CallTool(context.Background(), req)
|
||||
result, err := dExt.client.CallTool(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("MCP tool call failed: %w", err)
|
||||
}
|
||||
@@ -139,3 +141,22 @@ func NewDeviceWithDefault(platform, serial string) (device IDevice, err error) {
|
||||
|
||||
return device, err
|
||||
}
|
||||
|
||||
// SetMCPClients sets the external MCP clients for the driver
|
||||
func (dExt *XTDriver) SetMCPClients(clients map[string]client.MCPClient) {
|
||||
if dExt.loadedMCPClients == nil {
|
||||
dExt.loadedMCPClients = make(map[string]client.MCPClient)
|
||||
}
|
||||
for name, client := range clients {
|
||||
dExt.loadedMCPClients[name] = client
|
||||
}
|
||||
}
|
||||
|
||||
// GetMCPClient returns the MCP client for the specified server name
|
||||
func (dExt *XTDriver) GetMCPClient(serverName string) (client.MCPClient, bool) {
|
||||
if dExt.loadedMCPClients == nil {
|
||||
return nil, false
|
||||
}
|
||||
client, exists := dExt.loadedMCPClients[serverName]
|
||||
return client, exists
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user