feat: enhance signal handling and graceful shutdown for MCP integration

This commit is contained in:
lilong.129
2025-05-29 00:59:17 +08:00
parent c5fb391ef5
commit d3011d467e
7 changed files with 299 additions and 46 deletions

View File

@@ -6,8 +6,11 @@ import (
"fmt"
"io"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
"github.com/cloudwego/eino/components/tool"
@@ -26,6 +29,9 @@ type MCPHost struct {
connections map[string]*Connection
config *MCPConfig
withUIXT bool
ctx context.Context
cancel context.CancelFunc
shutdownCh chan struct{}
}
// Connection represents a connection to an MCP server
@@ -48,14 +54,22 @@ func NewMCPHost(configPath string, withUIXT bool) (*MCPHost, error) {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
host := &MCPHost{
connections: make(map[string]*Connection),
config: config,
withUIXT: withUIXT,
ctx: ctx,
cancel: cancel,
shutdownCh: make(chan struct{}),
}
// Set up signal handling
go host.handleSignals()
// Initialize MCP servers
if err := host.InitServers(context.Background()); err != nil {
if err := host.InitServers(ctx); err != nil {
cancel()
return nil, fmt.Errorf("failed to initialize MCP servers: %w", err)
}
@@ -93,6 +107,13 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
log.Debug().Str("server", serverName).Msg("connecting to MCP server")
// Check if context is cancelled
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Close existing connection if any
if existing, exists := h.connections[serverName]; exists {
if err := existing.Client.Close(); err != nil {
@@ -119,9 +140,12 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
}
mcpClient, err = client.NewStdioMCPClient(cfg.Command, env, cfg.Args...)
if stdioClient, ok := mcpClient.(*client.Client); ok {
stderr, _ := client.GetStderr(stdioClient)
startStdioLog(stderr, serverName)
if err == nil {
if stdioClient, ok := mcpClient.(*client.Client); ok {
stderr, _ := client.GetStderr(stdioClient)
startStdioLog(stderr, serverName, h.ctx)
log.Debug().Str("server", serverName).Msg("STDIO MCP server started")
}
}
default:
return fmt.Errorf("unsupported transport type: %s", config.GetType())
@@ -131,8 +155,11 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
return fmt.Errorf("failed to create client: %w", err)
}
// initialize client
_, err = mcpClient.Initialize(ctx, prepareClientInitRequest())
// initialize client with timeout
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
_, err = mcpClient.Initialize(initCtx, prepareClientInitRequest())
if err != nil {
mcpClient.Close()
return errors.Wrapf(err, "initialize MCP client for %s failed", serverName)
@@ -152,18 +179,59 @@ func (h *MCPHost) CloseServers() error {
defer h.mu.Unlock()
log.Info().Msg("Shutting down MCP servers...")
// Use a longer timeout for graceful shutdown
timeout := 5 * time.Second
for name, conn := range h.connections {
if err := conn.Client.Close(); err != nil {
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
} else {
delete(h.connections, name)
log.Info().Str("name", name).Msg("Server closed")
// Create a timeout context for each server
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// Close server in a goroutine with timeout
done := make(chan error, 1)
go func(serverName string, client client.MCPClient) {
done <- client.Close()
}(name, conn.Client)
select {
case err := <-done:
if err != nil {
// Check if it's a signal-related error (expected during CTRL+C)
if isSignalError(err) {
log.Debug().Str("name", name).Err(err).
Msg("Server terminated by signal (expected during shutdown)")
} else {
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
}
} else {
log.Info().Str("name", name).Msg("Server closed gracefully")
}
case <-ctx.Done():
log.Warn().Str("name", name).Msg("Server close timeout, forcing termination")
}
cancel()
delete(h.connections, name)
}
return nil
}
// isSignalError reports whether err looks like the by-product of a process
// being interrupted or torn down. The decision is made purely by substring
// matching on err.Error(); a nil error is never a signal error.
func isSignalError(err error) bool {
	if err == nil {
		return false
	}

	// Message fragments that are treated as signal-related.
	markers := []string{
		"signal: interrupt",
		"signal: terminated",
		"exit status 120",
		"exit status 130",
		"broken pipe",
		"connection reset",
	}

	msg := err.Error()
	for _, marker := range markers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// GetClient returns the client for the specified server
func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
h.mu.RLock()
@@ -244,6 +312,15 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc
func (h *MCPHost) InvokeTool(ctx context.Context,
serverName, toolName string, arguments map[string]any,
) (*mcp.CallToolResult, error) {
// Check if host is shutting down or context is cancelled
select {
case <-h.shutdownCh:
return nil, fmt.Errorf("MCP host is shutting down")
case <-ctx.Done():
return nil, ctx.Err()
default:
}
log.Info().Str("tool", toolName).Interface("args", arguments).
Str("server", serverName).Msg("invoke tool")
@@ -272,11 +349,26 @@ func (h *MCPHost) InvokeTool(ctx context.Context,
},
}
result, err := conn.CallTool(ctx, req)
// Add shorter timeout for tool invocation
toolCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
// Call tool and wait for result or cancellation
result, err := conn.CallTool(toolCtx, req)
if err != nil {
return nil, errors.Wrapf(err,
"call tool %s/%s failed", serverName, toolName)
// Check if it's a timeout or cancellation
select {
case <-h.shutdownCh:
return nil, fmt.Errorf("MCP host is shutting down")
case <-ctx.Done():
return nil, ctx.Err()
case <-toolCtx.Done():
return nil, fmt.Errorf("tool call timeout: %s/%s", serverName, toolName)
default:
return nil, errors.Wrapf(err, "call tool %s/%s failed", serverName, toolName)
}
}
if result.IsError {
if len(result.Content) > 0 {
return nil, fmt.Errorf("invoke tool %s/%s failed: %v",
@@ -366,11 +458,25 @@ func parseHeaders(headerList []string) map[string]string {
}
// startStdioLog starts a goroutine to print stdio logs
func startStdioLog(stderr io.Reader, serverName string) {
func startStdioLog(stderr io.Reader, serverName string, ctx context.Context) {
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
for {
select {
case <-ctx.Done():
log.Debug().Str("server", serverName).Msg("stopping stdio log due to context cancellation")
return
default:
if scanner.Scan() {
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
} else {
// Scanner finished or encountered error
if err := scanner.Err(); err != nil {
log.Debug().Str("server", serverName).Err(err).Msg("stdio log scanner error")
}
return
}
}
}
}()
}
@@ -392,3 +498,63 @@ func prepareClientInitRequest() mcp.InitializeRequest {
},
}
}
// handleSignals blocks until SIGINT or SIGTERM is delivered or the host
// context is cancelled. On a signal it logs the signal name and calls
// h.Shutdown; on context cancellation it returns without further action.
//
// The signal registration is removed when the function returns. Without
// that, the signals would keep being routed to a channel nobody reads, and a
// later Ctrl+C could no longer terminate the process through the default
// handler.
func (h *MCPHost) handleSignals() {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigCh)

	select {
	case sig := <-sigCh:
		log.Info().Str("signal", sig.String()).Msg("received signal, shutting down MCP servers")
		h.Shutdown()
	case <-h.ctx.Done():
		// Host was shut down through another path; nothing left to do.
	}
}
// Shutdown cancels the host context, marks the host as shutting down by
// closing shutdownCh, and then closes all MCP servers, waiting at most five
// seconds for that to finish. If the wait times out, the remaining
// connection entries are dropped via forceCloseAll.
//
// Shutdown may be called more than once and from several goroutines (for
// example from the signal handler and from the owner of the host); only the
// first call performs the shutdown, later calls return immediately.
func (h *MCPHost) Shutdown() {
	log.Debug().Msg("Starting MCP host shutdown")
	h.cancel()

	// The "already closed?" check and the close must happen atomically;
	// otherwise two concurrent callers can both reach close() and the second
	// one panics on an already closed channel.
	h.mu.Lock()
	select {
	case <-h.shutdownCh:
		h.mu.Unlock()
		log.Debug().Msg("MCP host already shutting down")
		return
	default:
		close(h.shutdownCh)
	}
	h.mu.Unlock()

	// Close all servers with timeout
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	done := make(chan struct{})
	go func() {
		defer close(done)
		if err := h.CloseServers(); err != nil {
			log.Error().Err(err).Msg("Failed to close MCP servers")
		}
	}()

	select {
	case <-done:
		log.Info().Msg("MCP servers shut down gracefully")
	case <-ctx.Done():
		log.Warn().Msg("MCP servers shutdown timeout, forcing exit")
		// NOTE(review): forceCloseAll locks h.mu, which CloseServers appears
		// to hold while it is still closing servers, so this call may wait
		// for CloseServers to finish — confirm this is acceptable.
		h.forceCloseAll()
	}
}
// forceCloseAll removes every remaining entry from h.connections while
// holding the write lock, logging a warning per server name. It only drops
// the map entries; it does not call Close on the underlying clients.
func (h *MCPHost) forceCloseAll() {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Snapshot the keys first, then remove them one by one.
	remaining := make([]string, 0, len(h.connections))
	for serverName := range h.connections {
		remaining = append(remaining, serverName)
	}

	for _, serverName := range remaining {
		log.Warn().Str("name", serverName).Msg("Force closing server")
		delete(h.connections, serverName)
	}
}

View File

@@ -23,6 +23,12 @@
"ABC": "123"
}
},
"evalpkgs": {
"command": "/Users/debugtalk/MyProjects/ByteDance/evalpkgs/dist/mcpserver",
"args": [],
"env": {
}
},
"disabled_server": {
"command": "echo",
"args": ["disabled"],