From d3011d467ea418c794b78206fd1ee7846a20b359 Mon Sep 17 00:00:00 2001
From: "lilong.129"
Date: Thu, 29 May 2025 00:59:17 +0800
Subject: [PATCH] feat: enhance signal handling and graceful shutdown for MCP integration

---
 internal/version/VERSION       |   2 +-
 mcphost/host.go                | 200 ++++++++++++++++++++++++++++++---
 mcphost/testdata/test.mcp.json |   6 +
 runner.go                      |  59 +++++++++-
 step_thinktime.go              |  16 ++-
 step_ui.go                     |  36 +++---
 uixt/driver_handler.go         |  26 ++++-
 7 files changed, 299 insertions(+), 46 deletions(-)

diff --git a/internal/version/VERSION b/internal/version/VERSION
index db149f9e..daaee069 100644
--- a/internal/version/VERSION
+++ b/internal/version/VERSION
@@ -1 +1 @@
-v5.0.0-beta-2505290011
+v5.0.0-beta-2505290059
diff --git a/mcphost/host.go b/mcphost/host.go
index 5935539b..ae3573d6 100644
--- a/mcphost/host.go
+++ b/mcphost/host.go
@@ -6,8 +6,11 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"os/signal"
 	"strings"
 	"sync"
+	"syscall"
+	"time"
 
 	mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
 	"github.com/cloudwego/eino/components/tool"
@@ -26,6 +29,9 @@ type MCPHost struct {
 	connections map[string]*Connection
 	config      *MCPConfig
 	withUIXT    bool
+	ctx         context.Context
+	cancel      context.CancelFunc
+	shutdownCh  chan struct{}
 }
 
 // Connection represents a connection to an MCP server
@@ -48,14 +54,22 @@ func NewMCPHost(configPath string, withUIXT bool) (*MCPHost, error) {
 		return nil, err
 	}
 
+	ctx, cancel := context.WithCancel(context.Background())
 	host := &MCPHost{
 		connections: make(map[string]*Connection),
 		config:      config,
 		withUIXT:    withUIXT,
+		ctx:         ctx,
+		cancel:      cancel,
+		shutdownCh:  make(chan struct{}),
 	}
 
+	// Set up signal handling
+	go host.handleSignals()
+
 	// Initialize MCP servers
-	if err := host.InitServers(context.Background()); err != nil {
+	if err := host.InitServers(ctx); err != nil {
+		cancel()
 		return nil, fmt.Errorf("failed to initialize MCP servers: %w", err)
 	}
 
@@ -93,6 +107,13 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
 
 	log.Debug().Str("server", serverName).Msg("connecting to MCP server")
 
+	// Check if context is cancelled
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	default:
+	}
+
 	// Close existing connection if any
 	if existing, exists := h.connections[serverName]; exists {
 		if err := existing.Client.Close(); err != nil {
@@ -119,9 +140,12 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
 		}
 		mcpClient, err = client.NewStdioMCPClient(cfg.Command, env, cfg.Args...)
 
-		if stdioClient, ok := mcpClient.(*client.Client); ok {
-			stderr, _ := client.GetStderr(stdioClient)
-			startStdioLog(stderr, serverName)
+		if err == nil {
+			if stdioClient, ok := mcpClient.(*client.Client); ok {
+				stderr, _ := client.GetStderr(stdioClient)
+				startStdioLog(h.ctx, stderr, serverName)
+				log.Debug().Str("server", serverName).Msg("STDIO MCP server started")
+			}
 		}
 	default:
 		return fmt.Errorf("unsupported transport type: %s", config.GetType())
@@ -131,8 +155,11 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
 		return fmt.Errorf("failed to create client: %w", err)
 	}
 
-	// initialize client
-	_, err = mcpClient.Initialize(ctx, prepareClientInitRequest())
+	// initialize client with timeout
+	initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
+	defer cancel()
+
+	_, err = mcpClient.Initialize(initCtx, prepareClientInitRequest())
 	if err != nil {
 		mcpClient.Close()
 		return errors.Wrapf(err, "initialize MCP client for %s failed", serverName)
@@ -152,18 +179,59 @@ func (h *MCPHost) CloseServers() error {
 	defer h.mu.Unlock()
 
 	log.Info().Msg("Shutting down MCP servers...")
+
+	// Allow up to 5 seconds for each server to close gracefully
+	timeout := 5 * time.Second
+
 	for name, conn := range h.connections {
-		if err := conn.Client.Close(); err != nil {
-			log.Error().Str("name", name).Err(err).Msg("Failed to close server")
-		} else {
-			delete(h.connections, name)
-			log.Info().Str("name", name).Msg("Server closed")
+		// Create a timeout context for each server
+		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+
+		// Close the server in a goroutine so the wait can be bounded
+		done := make(chan error, 1)
+		go func(c client.MCPClient) {
+			done <- c.Close()
+		}(conn.Client)
+
+		select {
+		case err := <-done:
+			if err != nil {
+				// Check if it's a signal-related error (expected during CTRL+C)
+				if isSignalError(err) {
+					log.Debug().Str("name", name).Err(err).
+						Msg("Server terminated by signal (expected during shutdown)")
+				} else {
+					log.Error().Str("name", name).Err(err).Msg("Failed to close server")
+				}
+			} else {
+				log.Info().Str("name", name).Msg("Server closed gracefully")
+			}
+		case <-ctx.Done():
+			log.Warn().Str("name", name).Msg("Server close timeout, forcing termination")
 		}
+
+		cancel()
+		delete(h.connections, name)
 	}
 
 	return nil
 }
 
+// isSignalError checks if the error is caused by signal interruption
+func isSignalError(err error) bool {
+	if err == nil {
+		return false
+	}
+	errStr := err.Error()
+	// Common signal-related error patterns
+	return strings.Contains(errStr, "signal: interrupt") ||
+		strings.Contains(errStr, "signal: terminated") ||
+		strings.Contains(errStr, "exit status 120") ||
+		strings.Contains(errStr, "exit status 130") ||
+		strings.Contains(errStr, "broken pipe") ||
+		strings.Contains(errStr, "connection reset")
+}
+
 // GetClient returns the client for the specified server
 func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
 	h.mu.RLock()
@@ -244,6 +312,15 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc
 func (h *MCPHost) InvokeTool(ctx context.Context,
 	serverName, toolName string, arguments map[string]any,
 ) (*mcp.CallToolResult, error) {
+	// Check if host is shutting down or context is cancelled
+	select {
+	case <-h.shutdownCh:
+		return nil, fmt.Errorf("MCP host is shutting down")
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
 	log.Info().Str("tool", toolName).Interface("args", arguments).
 		Str("server", serverName).Msg("invoke tool")
 
@@ -272,11 +349,26 @@ func (h *MCPHost) InvokeTool(ctx context.Context,
 		},
 	}
 
-	result, err := conn.CallTool(ctx, req)
+	// Bound each tool invocation to 15 seconds
+	toolCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
+	defer cancel()
+
+	// Call tool and wait for result or cancellation
+	result, err := conn.CallTool(toolCtx, req)
 	if err != nil {
-		return nil, errors.Wrapf(err,
-			"call tool %s/%s failed", serverName, toolName)
+		// Distinguish shutdown, cancellation, and timeout from plain failures
+		select {
+		case <-h.shutdownCh:
+			return nil, fmt.Errorf("MCP host is shutting down")
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-toolCtx.Done():
+			return nil, fmt.Errorf("tool call timeout: %s/%s", serverName, toolName)
+		default:
+			return nil, errors.Wrapf(err, "call tool %s/%s failed", serverName, toolName)
+		}
 	}
+
 	if result.IsError {
 		if len(result.Content) > 0 {
 			return nil, fmt.Errorf("invoke tool %s/%s failed: %v",
@@ -366,11 +458,25 @@ func parseHeaders(headerList []string) map[string]string {
 }
 
 // startStdioLog starts a goroutine to print stdio logs
-func startStdioLog(stderr io.Reader, serverName string) {
+func startStdioLog(ctx context.Context, stderr io.Reader, serverName string) {
 	go func() {
 		scanner := bufio.NewScanner(stderr)
-		for scanner.Scan() {
-			fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
+		for {
+			select {
+			case <-ctx.Done():
+				log.Debug().Str("server", serverName).Msg("stopping stdio log due to context cancellation")
+				return
+			default:
+				if scanner.Scan() {
+					fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
+				} else {
+					// Scanner finished or encountered an error
+					if err := scanner.Err(); err != nil {
+						log.Debug().Str("server", serverName).Err(err).Msg("stdio log scanner error")
+					}
+					return
+				}
+			}
 		}
 	}()
 }
@@ -392,3 +498,63 @@ func prepareClientInitRequest() mcp.InitializeRequest {
 		},
 	}
 }
+
+// handleSignals handles OS signals for graceful shutdown
+func (h *MCPHost) handleSignals() {
+	sigCh := make(chan os.Signal, 1)
+	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
+
+	select {
+	case sig := <-sigCh:
+		log.Info().Str("signal", sig.String()).Msg("received signal, shutting down MCP servers")
+		h.Shutdown()
+	case <-h.ctx.Done():
+		return
+	}
+}
+
+// Shutdown gracefully shuts down all MCP servers
+func (h *MCPHost) Shutdown() {
+	log.Debug().Msg("Starting MCP host shutdown")
+	h.cancel()
+
+	// Close shutdownCh exactly once to signal shutdown
+	select {
+	case <-h.shutdownCh:
+		// Already shutting down
+		log.Debug().Msg("MCP host already shutting down")
+		return
+	default:
+		close(h.shutdownCh)
+	}
+
+	// Close all servers with timeout
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		h.CloseServers()
+	}()
+
+	select {
+	case <-done:
+		log.Info().Msg("MCP servers shut down gracefully")
+	case <-ctx.Done():
+		log.Warn().Msg("MCP servers shutdown timeout, forcing exit")
+		// Force close any remaining connections
+		h.forceCloseAll()
+	}
+}
+
+// forceCloseAll forcefully closes all remaining connections
+func (h *MCPHost) forceCloseAll() {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	for name := range h.connections {
+		log.Warn().Str("name", name).Msg("Force closing server")
+		delete(h.connections, name)
+	}
+}
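
Note: the host.go changes above combine two standard Go shutdown patterns:
fanning OS signals into a context cancellation, and running each Close in a
goroutine so the wait can be bounded by a timeout. The following standalone
sketch shows the same two patterns in isolation; slowClose is a hypothetical
stand-in for a real MCP server close and is not part of this patch.

	package main

	import (
		"context"
		"fmt"
		"os"
		"os/signal"
		"syscall"
		"time"
	)

	// slowClose pretends to flush and close a server connection.
	func slowClose() error {
		time.Sleep(100 * time.Millisecond)
		return nil
	}

	func main() {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// Fan OS signals into a context cancellation, as handleSignals does.
		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
		go func() {
			select {
			case sig := <-sigCh:
				fmt.Println("received signal:", sig)
				cancel()
			case <-ctx.Done():
			}
		}()

		// Run Close in a goroutine so the wait can be bounded by a timeout,
		// mirroring the CloseServers loop.
		done := make(chan error, 1)
		go func() { done <- slowClose() }()

		closeCtx, closeCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer closeCancel()

		select {
		case err := <-done:
			fmt.Println("closed gracefully, err =", err)
		case <-closeCtx.Done():
			fmt.Println("close timed out, forcing termination")
		}
	}
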
diff --git a/mcphost/testdata/test.mcp.json b/mcphost/testdata/test.mcp.json
index 37c09fab..7d17ed2f 100644
--- a/mcphost/testdata/test.mcp.json
+++ b/mcphost/testdata/test.mcp.json
@@ -23,6 +23,12 @@
       "ABC": "123"
     }
   },
+  "evalpkgs": {
+    "command": "/Users/debugtalk/MyProjects/ByteDance/evalpkgs/dist/mcpserver",
+    "args": [],
+    "env": {
+    }
+  },
   "disabled_server": {
     "command": "echo",
     "args": ["disabled"],
diff --git a/runner.go b/runner.go
index b0e131e2..bdaa7237 100644
--- a/runner.go
+++ b/runner.go
@@ -13,6 +13,7 @@ import (
 	"reflect"
 	"strconv"
 	"strings"
+	"sync"
 	"syscall"
 	"testing"
 	"time"
@@ -225,19 +226,54 @@ func (r *HRPRunner) Run(testcases ...ITestCase) (err error) {
 		return err
 	}
 
-	// quit all plugins
+	// collect all MCP hosts for cleanup
+	var mcpHosts []*mcphost.MCPHost
+	var cleanupOnce sync.Once
+
+	// quit all plugins and close MCP hosts
 	defer func() {
-		pluginMap.Range(func(key, value interface{}) bool {
-			if plugin, ok := value.(funplugin.IPlugin); ok {
-				plugin.Quit()
+		cleanupOnce.Do(func() {
+			pluginMap.Range(func(key, value interface{}) bool {
+				if plugin, ok := value.(funplugin.IPlugin); ok {
+					plugin.Quit()
+				}
+				return true
+			})
+
+			// Close all MCP hosts with timeout
+			if len(mcpHosts) > 0 {
+				done := make(chan struct{})
+				go func() {
+					defer close(done)
+					for _, host := range mcpHosts {
+						if host != nil {
+							host.Shutdown()
+						}
+					}
+				}()
+
+				// Wait for cleanup with timeout
+				select {
+				case <-done:
+					log.Debug().Msg("All MCP hosts cleaned up successfully")
+				case <-time.After(10 * time.Second):
+					log.Warn().Msg("MCP hosts cleanup timeout")
+				}
 			}
-			return true
 		})
 	}()
 
 	var runErr error
 	// run testcase one by one
 	for _, testcase := range testCases {
+		// check for interrupt signal before processing each testcase
+		select {
+		case <-r.interruptSignal:
+			log.Warn().Msg("interrupted in main runner")
+			return errors.Wrap(code.InterruptError, "main runner interrupted")
+		default:
+		}
+
 		// each testcase has its own case runner
 		caseRunner, err := NewCaseRunner(*testcase, r)
 		if err != nil {
@@ -245,7 +281,20 @@
 			return err
 		}
 
+		// collect MCP host for cleanup
+		if caseRunner.parser.MCPHost != nil {
+			mcpHosts = append(mcpHosts, caseRunner.parser.MCPHost)
+		}
+
 		for it := caseRunner.parametersIterator; it.HasNext(); {
+			// check for interrupt signal before each iteration
+			select {
+			case <-r.interruptSignal:
+				log.Warn().Msg("interrupted in main runner")
+				return errors.Wrap(code.InterruptError, "main runner interrupted")
+			default:
+			}
+
 			// case runner can run multiple times with different parameters
 			// each run has its own session runner
 			sessionRunner := caseRunner.NewSession()
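
Note: the runner.go defer above is a once-guarded, deadline-bounded cleanup:
sync.Once makes the cleanup idempotent, and the select bounds how long the
runner waits for hosts to shut down. A minimal sketch of the same pattern,
where fakeHost is a hypothetical stand-in for *mcphost.MCPHost:

	package main

	import (
		"fmt"
		"sync"
		"time"
	)

	// fakeHost stands in for *mcphost.MCPHost; only Shutdown matters here.
	type fakeHost struct{ name string }

	func (h *fakeHost) Shutdown() { time.Sleep(50 * time.Millisecond) }

	func main() {
		hosts := []*fakeHost{{name: "a"}, {name: "b"}}
		var cleanupOnce sync.Once

		cleanup := func() {
			cleanupOnce.Do(func() {
				// Shut hosts down in a goroutine so the wait can be bounded.
				done := make(chan struct{})
				go func() {
					defer close(done)
					for _, h := range hosts {
						h.Shutdown()
					}
				}()

				select {
				case <-done:
					fmt.Println("all hosts cleaned up")
				case <-time.After(10 * time.Second):
					fmt.Println("cleanup timeout")
				}
			})
		}
		defer cleanup()

		// ... the test loop would run here; cleanup runs exactly once on return.
	}
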
diff --git a/step_thinktime.go b/step_thinktime.go
index 596ad676..0c09ec2d 100644
--- a/step_thinktime.go
+++ b/step_thinktime.go
@@ -1,6 +1,7 @@
 package hrp
 
 import (
+	"fmt"
 	"time"
 
 	"github.com/rs/zerolog/log"
@@ -76,6 +77,19 @@ func (s *StepThinkTime) Run(r *SessionRunner) (*StepResult, error) {
 		}
 	}
 
-	time.Sleep(tt)
+	// Use an interruptible sleep that can respond to signals
+	log.Debug().Float64("duration_ms", float64(tt.Milliseconds())).Msg("starting think time")
+
+	select {
+	case <-time.After(tt):
+		// Normal completion
+		log.Debug().Float64("duration_ms", float64(tt.Milliseconds())).Msg("think time completed normally")
+	case <-r.caseRunner.hrpRunner.interruptSignal:
+		// Interrupted by signal
+		log.Info().Float64("planned_duration_ms", float64(tt.Milliseconds())).
+			Msg("think time interrupted by signal")
+		return stepResult, fmt.Errorf("think time interrupted")
+	}
+
 	return stepResult, nil
 }
diff --git a/step_ui.go b/step_ui.go
index 4b1e0265..62ec4766 100644
--- a/step_ui.go
+++ b/step_ui.go
@@ -741,23 +741,21 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
 		attachments["error"] = err.Error()
 
 		// save foreground app
-		if uiDriver != nil {
-			startTime := time.Now()
-			actionResult := &ActionResult{
-				MobileAction: uixt.MobileAction{
-					Method: option.ACTION_GetForegroundApp,
-					Params: "[ForDebug] check foreground app",
-				},
-				StartTime: startTime.Unix(),
-			}
-			if app, err1 := uiDriver.ForegroundInfo(); err1 == nil {
-				attachments["foreground_app"] = app.AppBaseInfo
-			} else {
-				log.Warn().Err(err1).Msg("save foreground app failed, ignore")
-			}
-			actionResult.Elapsed = time.Since(startTime).Milliseconds()
-			stepResult.Actions = append(stepResult.Actions, actionResult)
+		startTime := time.Now()
+		actionResult := &ActionResult{
+			MobileAction: uixt.MobileAction{
+				Method: option.ACTION_GetForegroundApp,
+				Params: "[ForDebug] check foreground app",
+			},
+			StartTime: startTime.Unix(),
 		}
+		if app, err1 := uiDriver.ForegroundInfo(); err1 == nil {
+			attachments["foreground_app"] = app.AppBaseInfo
+		} else {
+			log.Warn().Err(err1).Msg("save foreground app failed, ignore")
+		}
+		actionResult.Elapsed = time.Since(startTime).Milliseconds()
+		stepResult.Actions = append(stepResult.Actions, actionResult)
 	}
 
 	// automatic handling of pop-up windows on each step finished
@@ -782,10 +780,8 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
 	}
 
 	// save attachments
-	if uiDriver != nil {
-		for key, value := range uiDriver.GetData(true) {
-			attachments[key] = value
-		}
+	for key, value := range uiDriver.GetData(true) {
+		attachments[key] = value
 	}
 	stepResult.Attachments = attachments
 	stepResult.Elapsed = time.Since(start).Milliseconds()
diff --git a/uixt/driver_handler.go b/uixt/driver_handler.go
index cf26a064..b9422ca9 100644
--- a/uixt/driver_handler.go
+++ b/uixt/driver_handler.go
@@ -194,11 +194,33 @@ func callMCPActionTool(driver IDriver,
 	// Get XTDriver from cache
 	dExt := getXTDriverFromCache(driver)
 	if dExt == nil {
+		log.Warn().Msg("XTDriver not found in cache, skipping MCP tool call")
 		return
 	}
 
-	dExt.CallMCPTool(context.Background(),
-		serverName, actionType, arguments)
+	// Bound the MCP tool call with a 10-second timeout
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	log.Debug().Str("server", serverName).Str("action", actionType).
+		Interface("arguments", arguments).Msg("calling MCP action tool")
+
+	// Call MCP tool with timeout context
+	result, err := dExt.CallMCPTool(ctx, serverName, actionType, arguments)
+	if err != nil {
+		// Classify error types for better debugging
+		if ctx.Err() == context.DeadlineExceeded {
+			log.Warn().Str("server", serverName).Str("action", actionType).
+				Msg("MCP action tool call timeout")
+		} else {
+			log.Warn().Err(err).Str("server", serverName).Str("action", actionType).
+				Msg("MCP action tool call failed")
+		}
+		return
+	}
+
+	log.Debug().Str("server", serverName).Str("action", actionType).
+		Interface("result", result).Msg("MCP action tool call succeeded")
 }
 
 // getAntiRisk_SetTouchInfo_Arguments gets arguments for SetTouchInfo MCP tool
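
Note: step_thinktime.go above replaces time.Sleep with a select over a timer
and the runner's interrupt channel. The sketch below shows that pattern on its
own; interrupt is a hypothetical stand-in for interruptSignal, and it uses
time.NewTimer plus Stop so the timer is released on early return (time.After
keeps its timer alive until it fires).

	package main

	import (
		"errors"
		"fmt"
		"time"
	)

	// interruptibleSleep sleeps for d unless interrupt fires first.
	func interruptibleSleep(d time.Duration, interrupt <-chan struct{}) error {
		t := time.NewTimer(d)
		defer t.Stop() // release the timer if we return early

		select {
		case <-t.C:
			return nil // slept the full duration
		case <-interrupt:
			return errors.New("think time interrupted")
		}
	}

	func main() {
		interrupt := make(chan struct{})
		go func() {
			time.Sleep(20 * time.Millisecond)
			close(interrupt) // simulate CTRL+C arriving mid-sleep
		}()

		err := interruptibleSleep(5*time.Second, interrupt)
		fmt.Println("result:", err) // prints the interruption error
	}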