mirror of
https://github.com/httprunner/httprunner.git
synced 2026-06-25 17:44:02 +08:00
feat: enhance signal handling and graceful shutdown for MCP integration
This commit is contained in:
@@ -1 +1 @@
|
||||
v5.0.0-beta-2505290011
|
||||
v5.0.0-beta-2505290059
|
||||
|
||||
200
mcphost/host.go
200
mcphost/host.go
@@ -6,8 +6,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
@@ -26,6 +29,9 @@ type MCPHost struct {
|
||||
connections map[string]*Connection
|
||||
config *MCPConfig
|
||||
withUIXT bool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
shutdownCh chan struct{}
|
||||
}
|
||||
|
||||
// Connection represents a connection to an MCP server
|
||||
@@ -48,14 +54,22 @@ func NewMCPHost(configPath string, withUIXT bool) (*MCPHost, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
host := &MCPHost{
|
||||
connections: make(map[string]*Connection),
|
||||
config: config,
|
||||
withUIXT: withUIXT,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
shutdownCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Set up signal handling
|
||||
go host.handleSignals()
|
||||
|
||||
// Initialize MCP servers
|
||||
if err := host.InitServers(context.Background()); err != nil {
|
||||
if err := host.InitServers(ctx); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to initialize MCP servers: %w", err)
|
||||
}
|
||||
|
||||
@@ -93,6 +107,13 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
|
||||
|
||||
log.Debug().Str("server", serverName).Msg("connecting to MCP server")
|
||||
|
||||
// Check if context is cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Close existing connection if any
|
||||
if existing, exists := h.connections[serverName]; exists {
|
||||
if err := existing.Client.Close(); err != nil {
|
||||
@@ -119,9 +140,12 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
|
||||
}
|
||||
|
||||
mcpClient, err = client.NewStdioMCPClient(cfg.Command, env, cfg.Args...)
|
||||
if stdioClient, ok := mcpClient.(*client.Client); ok {
|
||||
stderr, _ := client.GetStderr(stdioClient)
|
||||
startStdioLog(stderr, serverName)
|
||||
if err == nil {
|
||||
if stdioClient, ok := mcpClient.(*client.Client); ok {
|
||||
stderr, _ := client.GetStderr(stdioClient)
|
||||
startStdioLog(stderr, serverName, h.ctx)
|
||||
log.Debug().Str("server", serverName).Msg("STDIO MCP server started")
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported transport type: %s", config.GetType())
|
||||
@@ -131,8 +155,11 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
|
||||
return fmt.Errorf("failed to create client: %w", err)
|
||||
}
|
||||
|
||||
// initialize client
|
||||
_, err = mcpClient.Initialize(ctx, prepareClientInitRequest())
|
||||
// initialize client with timeout
|
||||
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err = mcpClient.Initialize(initCtx, prepareClientInitRequest())
|
||||
if err != nil {
|
||||
mcpClient.Close()
|
||||
return errors.Wrapf(err, "initialize MCP client for %s failed", serverName)
|
||||
@@ -152,18 +179,59 @@ func (h *MCPHost) CloseServers() error {
|
||||
defer h.mu.Unlock()
|
||||
|
||||
log.Info().Msg("Shutting down MCP servers...")
|
||||
|
||||
// Use a longer timeout for graceful shutdown
|
||||
timeout := 5 * time.Second
|
||||
|
||||
for name, conn := range h.connections {
|
||||
if err := conn.Client.Close(); err != nil {
|
||||
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
|
||||
} else {
|
||||
delete(h.connections, name)
|
||||
log.Info().Str("name", name).Msg("Server closed")
|
||||
// Create a timeout context for each server
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
// Close server in a goroutine with timeout
|
||||
done := make(chan error, 1)
|
||||
go func(serverName string, client client.MCPClient) {
|
||||
done <- client.Close()
|
||||
}(name, conn.Client)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
// Check if it's a signal-related error (expected during CTRL+C)
|
||||
if isSignalError(err) {
|
||||
log.Debug().Str("name", name).Err(err).
|
||||
Msg("Server terminated by signal (expected during shutdown)")
|
||||
} else {
|
||||
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
|
||||
}
|
||||
} else {
|
||||
log.Info().Str("name", name).Msg("Server closed gracefully")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Warn().Str("name", name).Msg("Server close timeout, forcing termination")
|
||||
}
|
||||
|
||||
cancel()
|
||||
delete(h.connections, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isSignalError reports whether err looks like a side effect of a process
// being interrupted or torn down, as opposed to a genuine close failure.
//
// The check is a substring match on err.Error(), so it also matches errors
// that have been wrapped with additional context. A nil error is never a
// signal error.
func isSignalError(err error) bool {
	if err == nil {
		return false
	}

	// Message fragments treated as "expected during shutdown".
	patterns := [...]string{
		"signal: interrupt",
		"signal: terminated",
		"exit status 120", // kept from the original list; NOTE(review): origin of this code is not visible here
		"exit status 130", // 128 + SIGINT (2)
		"exit status 143", // 128 + SIGTERM (15); pairs with "signal: terminated"
		"broken pipe",
		"connection reset",
	}

	msg := err.Error()
	for _, pattern := range patterns {
		if strings.Contains(msg, pattern) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// GetClient returns the client for the specified server
|
||||
func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
|
||||
h.mu.RLock()
|
||||
@@ -244,6 +312,15 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc
|
||||
func (h *MCPHost) InvokeTool(ctx context.Context,
|
||||
serverName, toolName string, arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
// Check if host is shutting down or context is cancelled
|
||||
select {
|
||||
case <-h.shutdownCh:
|
||||
return nil, fmt.Errorf("MCP host is shutting down")
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
log.Info().Str("tool", toolName).Interface("args", arguments).
|
||||
Str("server", serverName).Msg("invoke tool")
|
||||
|
||||
@@ -272,11 +349,26 @@ func (h *MCPHost) InvokeTool(ctx context.Context,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := conn.CallTool(ctx, req)
|
||||
// Add shorter timeout for tool invocation
|
||||
toolCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call tool and wait for result or cancellation
|
||||
result, err := conn.CallTool(toolCtx, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err,
|
||||
"call tool %s/%s failed", serverName, toolName)
|
||||
// Check if it's a timeout or cancellation
|
||||
select {
|
||||
case <-h.shutdownCh:
|
||||
return nil, fmt.Errorf("MCP host is shutting down")
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-toolCtx.Done():
|
||||
return nil, fmt.Errorf("tool call timeout: %s/%s", serverName, toolName)
|
||||
default:
|
||||
return nil, errors.Wrapf(err, "call tool %s/%s failed", serverName, toolName)
|
||||
}
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
if len(result.Content) > 0 {
|
||||
return nil, fmt.Errorf("invoke tool %s/%s failed: %v",
|
||||
@@ -366,11 +458,25 @@ func parseHeaders(headerList []string) map[string]string {
|
||||
}
|
||||
|
||||
// startStdioLog starts a goroutine to print stdio logs
|
||||
func startStdioLog(stderr io.Reader, serverName string) {
|
||||
func startStdioLog(stderr io.Reader, serverName string, ctx context.Context) {
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug().Str("server", serverName).Msg("stopping stdio log due to context cancellation")
|
||||
return
|
||||
default:
|
||||
if scanner.Scan() {
|
||||
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
|
||||
} else {
|
||||
// Scanner finished or encountered error
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Debug().Str("server", serverName).Err(err).Msg("stdio log scanner error")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -392,3 +498,63 @@ func prepareClientInitRequest() mcp.InitializeRequest {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleSignals handles OS signals for graceful shutdown
|
||||
func (h *MCPHost) handleSignals() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case sig := <-sigCh:
|
||||
log.Info().Str("signal", sig.String()).Msg("received signal, shutting down MCP servers")
|
||||
h.Shutdown()
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down all MCP servers
|
||||
func (h *MCPHost) Shutdown() {
|
||||
log.Debug().Msg("Starting MCP host shutdown")
|
||||
h.cancel()
|
||||
|
||||
// Close shutdown channel to signal shutdown
|
||||
select {
|
||||
case <-h.shutdownCh:
|
||||
// Already shutting down
|
||||
log.Debug().Msg("MCP host already shutting down")
|
||||
return
|
||||
default:
|
||||
close(h.shutdownCh)
|
||||
}
|
||||
|
||||
// Close all servers with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
h.CloseServers()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Info().Msg("MCP servers shut down gracefully")
|
||||
case <-ctx.Done():
|
||||
log.Warn().Msg("MCP servers shutdown timeout, forcing exit")
|
||||
// Force close any remaining connections
|
||||
h.forceCloseAll()
|
||||
}
|
||||
}
|
||||
|
||||
// forceCloseAll forcefully closes all remaining connections
|
||||
func (h *MCPHost) forceCloseAll() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for name := range h.connections {
|
||||
log.Warn().Str("name", name).Msg("Force closing server")
|
||||
delete(h.connections, name)
|
||||
}
|
||||
}
|
||||
|
||||
6
mcphost/testdata/test.mcp.json
vendored
6
mcphost/testdata/test.mcp.json
vendored
@@ -23,6 +23,12 @@
|
||||
"ABC": "123"
|
||||
}
|
||||
},
|
||||
"evalpkgs": {
|
||||
"command": "/Users/debugtalk/MyProjects/ByteDance/evalpkgs/dist/mcpserver",
|
||||
"args": [],
|
||||
"env": {
|
||||
}
|
||||
},
|
||||
"disabled_server": {
|
||||
"command": "echo",
|
||||
"args": ["disabled"],
|
||||
|
||||
59
runner.go
59
runner.go
@@ -13,6 +13,7 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -225,19 +226,54 @@ func (r *HRPRunner) Run(testcases ...ITestCase) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// quit all plugins
|
||||
// collect all MCP hosts for cleanup
|
||||
var mcpHosts []*mcphost.MCPHost
|
||||
var cleanupOnce sync.Once
|
||||
|
||||
// quit all plugins and close MCP hosts
|
||||
defer func() {
|
||||
pluginMap.Range(func(key, value interface{}) bool {
|
||||
if plugin, ok := value.(funplugin.IPlugin); ok {
|
||||
plugin.Quit()
|
||||
cleanupOnce.Do(func() {
|
||||
pluginMap.Range(func(key, value interface{}) bool {
|
||||
if plugin, ok := value.(funplugin.IPlugin); ok {
|
||||
plugin.Quit()
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Close all MCP hosts with timeout
|
||||
if len(mcpHosts) > 0 {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for _, host := range mcpHosts {
|
||||
if host != nil {
|
||||
host.Shutdown()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for cleanup with timeout
|
||||
select {
|
||||
case <-done:
|
||||
log.Debug().Msg("All MCP hosts cleaned up successfully")
|
||||
case <-time.After(10 * time.Second):
|
||||
log.Warn().Msg("MCP hosts cleanup timeout")
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}()
|
||||
|
||||
var runErr error
|
||||
// run testcase one by one
|
||||
for _, testcase := range testCases {
|
||||
// check for interrupt signal before processing each testcase
|
||||
select {
|
||||
case <-r.interruptSignal:
|
||||
log.Warn().Msg("interrupted in main runner")
|
||||
return errors.Wrap(code.InterruptError, "main runner interrupted")
|
||||
default:
|
||||
}
|
||||
|
||||
// each testcase has its own case runner
|
||||
caseRunner, err := NewCaseRunner(*testcase, r)
|
||||
if err != nil {
|
||||
@@ -245,7 +281,20 @@ func (r *HRPRunner) Run(testcases ...ITestCase) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// collect MCP host for cleanup
|
||||
if caseRunner.parser.MCPHost != nil {
|
||||
mcpHosts = append(mcpHosts, caseRunner.parser.MCPHost)
|
||||
}
|
||||
|
||||
for it := caseRunner.parametersIterator; it.HasNext(); {
|
||||
// check for interrupt signal before each iteration
|
||||
select {
|
||||
case <-r.interruptSignal:
|
||||
log.Warn().Msg("interrupted in main runner")
|
||||
return errors.Wrap(code.InterruptError, "main runner interrupted")
|
||||
default:
|
||||
}
|
||||
|
||||
// case runner can run multiple times with different parameters
|
||||
// each run has its own session runner
|
||||
sessionRunner := caseRunner.NewSession()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package hrp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -76,6 +77,19 @@ func (s *StepThinkTime) Run(r *SessionRunner) (*StepResult, error) {
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(tt)
|
||||
// Use interruptible sleep that can respond to signals
|
||||
log.Debug().Float64("duration_ms", float64(tt.Milliseconds())).Msg("starting think time")
|
||||
|
||||
select {
|
||||
case <-time.After(tt):
|
||||
// Normal completion
|
||||
log.Debug().Float64("duration_ms", float64(tt.Milliseconds())).Msg("think time completed normally")
|
||||
case <-r.caseRunner.hrpRunner.interruptSignal:
|
||||
// Interrupted by signal
|
||||
log.Info().Float64("planned_duration_ms", float64(tt.Milliseconds())).
|
||||
Msg("think time interrupted by signal")
|
||||
return stepResult, fmt.Errorf("think time interrupted")
|
||||
}
|
||||
|
||||
return stepResult, nil
|
||||
}
|
||||
|
||||
36
step_ui.go
36
step_ui.go
@@ -741,23 +741,21 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
|
||||
attachments["error"] = err.Error()
|
||||
|
||||
// save foreground app
|
||||
if uiDriver != nil {
|
||||
startTime := time.Now()
|
||||
actionResult := &ActionResult{
|
||||
MobileAction: uixt.MobileAction{
|
||||
Method: option.ACTION_GetForegroundApp,
|
||||
Params: "[ForDebug] check foreground app",
|
||||
},
|
||||
StartTime: startTime.Unix(),
|
||||
}
|
||||
if app, err1 := uiDriver.ForegroundInfo(); err1 == nil {
|
||||
attachments["foreground_app"] = app.AppBaseInfo
|
||||
} else {
|
||||
log.Warn().Err(err1).Msg("save foreground app failed, ignore")
|
||||
}
|
||||
actionResult.Elapsed = time.Since(startTime).Milliseconds()
|
||||
stepResult.Actions = append(stepResult.Actions, actionResult)
|
||||
startTime := time.Now()
|
||||
actionResult := &ActionResult{
|
||||
MobileAction: uixt.MobileAction{
|
||||
Method: option.ACTION_GetForegroundApp,
|
||||
Params: "[ForDebug] check foreground app",
|
||||
},
|
||||
StartTime: startTime.Unix(),
|
||||
}
|
||||
if app, err1 := uiDriver.ForegroundInfo(); err1 == nil {
|
||||
attachments["foreground_app"] = app.AppBaseInfo
|
||||
} else {
|
||||
log.Warn().Err(err1).Msg("save foreground app failed, ignore")
|
||||
}
|
||||
actionResult.Elapsed = time.Since(startTime).Milliseconds()
|
||||
stepResult.Actions = append(stepResult.Actions, actionResult)
|
||||
}
|
||||
|
||||
// automatic handling of pop-up windows on each step finished
|
||||
@@ -782,10 +780,8 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
|
||||
}
|
||||
|
||||
// save attachments
|
||||
if uiDriver != nil {
|
||||
for key, value := range uiDriver.GetData(true) {
|
||||
attachments[key] = value
|
||||
}
|
||||
for key, value := range uiDriver.GetData(true) {
|
||||
attachments[key] = value
|
||||
}
|
||||
stepResult.Attachments = attachments
|
||||
stepResult.Elapsed = time.Since(start).Milliseconds()
|
||||
|
||||
@@ -194,11 +194,33 @@ func callMCPActionTool(driver IDriver,
|
||||
// Get XTDriver from cache
|
||||
dExt := getXTDriverFromCache(driver)
|
||||
if dExt == nil {
|
||||
log.Warn().Msg("XTDriver not found in cache, skipping MCP tool call")
|
||||
return
|
||||
}
|
||||
|
||||
dExt.CallMCPTool(context.Background(),
|
||||
serverName, actionType, arguments)
|
||||
// Create a context with timeout that can be cancelled
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
log.Debug().Str("server", serverName).Str("action", actionType).
|
||||
Interface("arguments", arguments).Msg("calling MCP action tool")
|
||||
|
||||
// Call MCP tool with timeout context
|
||||
result, err := dExt.CallMCPTool(ctx, serverName, actionType, arguments)
|
||||
if err != nil {
|
||||
// Classify error types for better debugging
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
log.Warn().Str("server", serverName).Str("action", actionType).
|
||||
Msg("MCP action tool call timeout")
|
||||
} else {
|
||||
log.Warn().Err(err).Str("server", serverName).Str("action", actionType).
|
||||
Msg("MCP action tool call failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Str("server", serverName).Str("action", actionType).
|
||||
Interface("result", result).Msg("MCP action tool call succeeded")
|
||||
}
|
||||
|
||||
// getAntiRisk_SetTouchInfo_Arguments gets arguments for SetTouchInfo MCP tool
|
||||
|
||||
Reference in New Issue
Block a user