diff --git a/internal/version/VERSION b/internal/version/VERSION index c0405126..a4dece59 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2505161141 +v5.0.0-beta-2505161143 diff --git a/pkg/mcphost/host.go b/pkg/mcphost/host.go new file mode 100644 index 00000000..a0401df0 --- /dev/null +++ b/pkg/mcphost/host.go @@ -0,0 +1,298 @@ +package mcphost + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "strings" + "sync" + + "github.com/httprunner/httprunner/v5/internal/version" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" +) + +// MCPTools represents tools from a single MCP server +type MCPTools struct { + Name string + Tools []mcp.Tool + Err error +} + +// MCPHost manages MCP server connections and tools +type MCPHost struct { + mu sync.RWMutex + connections map[string]*Connection + config *MCPConfig +} + +// Connection represents a connection to an MCP server +type Connection struct { + Client client.MCPClient + Config ServerConfig +} + +// NewMCPHost creates a new MCPHost instance +func NewMCPHost(configPath string) (*MCPHost, error) { + config, err := LoadMCPConfig(configPath) + if err != nil { + return nil, err + } + return &MCPHost{ + connections: make(map[string]*Connection), + config: config, + }, nil +} + +// parseHeaders parses header strings into a map +func parseHeaders(headerList []string) map[string]string { + headers := make(map[string]string) + for _, header := range headerList { + parts := strings.SplitN(header, ":", 2) + if len(parts) == 2 { + headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } + } + return headers +} + +// startStdioLog starts a goroutine to print stdio logs +func startStdioLog(stderr io.Reader, serverName string) { + go func() { + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text()) + } + }() +} + +// prepareClientInitRequest creates a standard initialization request +func prepareClientInitRequest() mcp.InitializeRequest { + return mcp.InitializeRequest{ + Params: struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + ClientInfo mcp.Implementation `json:"clientInfo"` + }{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: "hrp-mcphost", + Version: version.GetVersionInfo(), + }, + }, + } +} + +// InitServers initializes all MCP servers +func (h *MCPHost) InitServers(ctx context.Context) error { + for name, server := range h.config.MCPServers { + if server.Config.IsDisabled() { + continue + } + + if err := h.connectToServer(ctx, name, server.Config); err != nil { + return fmt.Errorf("failed to connect to server %s: %w", name, err) + } + } + return nil +} + +// GetClient returns the client for the specified server +func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + conn, exists := h.connections[serverName] + if !exists { + return nil, fmt.Errorf("no connection found for server %s", serverName) + } + + return conn.Client, nil +} + +// connectToServer establishes connection to a single MCP server +func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config ServerConfig) error { + h.mu.Lock() + defer h.mu.Unlock() + + log.Debug().Str("server", serverName).Msg("connecting to MCP server") + + // Close existing connection if any + if existing, exists := h.connections[serverName]; exists { + if err := existing.Client.Close(); err != nil { + return fmt.Errorf("failed to close existing connection: %w", err) + } + delete(h.connections, serverName) + } + + var mcpClient client.MCPClient + var err error + + // create client based on server type + switch cfg := config.(type) { + case SSEServerConfig: + mcpClient, err = client.NewSSEMCPClient(cfg.Url, client.WithHeaders(parseHeaders(cfg.Headers))) + case STDIOServerConfig: + env := make([]string, 0, len(cfg.Env)) + for k, v := range cfg.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + mcpClient, err = client.NewStdioMCPClient(cfg.Command, env, cfg.Args...) + if stdioClient, ok := mcpClient.(*client.Client); ok { + stderr, _ := client.GetStderr(stdioClient) + startStdioLog(stderr, serverName) + } + default: + return fmt.Errorf("unsupported transport type: %s", config.GetType()) + } + + if err != nil { + return fmt.Errorf("failed to create client: %w", err) + } + + // initialize client + _, err = mcpClient.Initialize(ctx, prepareClientInitRequest()) + if err != nil { + mcpClient.Close() + return errors.Wrapf(err, "initialize MCP client for %s failed", serverName) + } + + log.Info().Str("server", serverName).Msg("connected to MCP server") + h.connections[serverName] = &Connection{ + Client: mcpClient, + Config: config, + } + return nil +} + +// GetTools fetches available tools from all connected MCP servers +func (h *MCPHost) GetTools(ctx context.Context) map[string]MCPTools { + h.mu.RLock() + defer h.mu.RUnlock() + + results := make(map[string]MCPTools) + + for serverName, conn := range h.connections { + if conn.Config.IsDisabled() { + continue + } + + listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + results[serverName] = MCPTools{ + Name: serverName, + Tools: nil, + Err: fmt.Errorf("failed to get tools: %w", err), + } + continue + } + + results[serverName] = MCPTools{ + Name: serverName, + Tools: listResults.Tools, + Err: nil, + } + } + + return results +} + +// GetTool returns a specific tool from a server +func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mcp.Tool, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + mcpTools, exists := h.GetTools(ctx)[serverName] + if !exists { + return nil, fmt.Errorf("no connection found for server %s", serverName) + } else if mcpTools.Err != nil { + return nil, mcpTools.Err + } + + for _, tool := range mcpTools.Tools { + if tool.Name == toolName { + return &tool, nil + } + } + + return nil, fmt.Errorf("tool %s not found", toolName) +} + +// handleToolError handles tool execution errors +func handleToolError(result *mcp.CallToolResult) error { + if !result.IsError { + return nil + } + if len(result.Content) > 0 { + return fmt.Errorf("tool error: %v", result.Content[0]) + } + return fmt.Errorf("tool error: unknown error") +} + +// InvokeTool calls a tool with the given arguments +func (h *MCPHost) InvokeTool(ctx context.Context, + serverName, toolName string, arguments map[string]any, +) (*mcp.CallToolResult, error) { + log.Info().Str("tool", toolName).Interface("args", arguments). + Str("server", serverName).Msg("invoke tool") + + conn, err := h.GetClient(serverName) + if err != nil { + return nil, errors.Wrapf(err, + "get mcp client for server %s failed", serverName) + } + + mcpTool, err := h.GetTool(ctx, serverName, toolName) + if err != nil { + return nil, errors.Wrapf(err, + "get mcp tool %s/%s failed", serverName, toolName) + } + + req := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Name: mcpTool.Name, + Arguments: arguments, + }, + } + + result, err := conn.CallTool(ctx, req) + if err != nil { + return nil, errors.Wrapf(err, + "call tool %s/%s failed", serverName, toolName) + } + + if err := handleToolError(result); err != nil { + return nil, err + } + + return result, nil +} + +// CloseServers closes all connected MCP servers +func (h *MCPHost) CloseServers() error { + h.mu.Lock() + defer h.mu.Unlock() + + log.Info().Msg("Shutting down MCP servers...") + for name, conn := range h.connections { + if err := conn.Client.Close(); err != nil { + log.Error().Str("name", name).Err(err).Msg("Failed to close server") + } else { + delete(h.connections, name) + log.Info().Str("name", name).Msg("Server closed") + } + } + + return nil +} diff --git a/pkg/mcphost/host_test.go b/pkg/mcphost/host_test.go new file mode 100644 index 00000000..b581e55e --- /dev/null +++ b/pkg/mcphost/host_test.go @@ -0,0 +1,228 @@ +package mcphost + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMCPHost(t *testing.T) { + // Test with valid config file + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + assert.NotNil(t, host) + assert.NotNil(t, host.config) + assert.NotEmpty(t, host.config.MCPServers) + + // Test with non-existent config file + host, err = NewMCPHost("./testdata/non_existent.json") + require.Error(t, err, "expected error when config file does not exist") + assert.Nil(t, host) +} + +func TestInitServers(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + // Verify connections are established + assert.Equal(t, 2, len(host.connections)) + assert.Contains(t, host.connections, "filesystem") + assert.Contains(t, host.connections, "weather") +} + +func TestGetClient(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + // Test getting existing client + client, err := host.GetClient("weather") + require.NoError(t, err) + assert.NotNil(t, client) + + // Test getting non-existent client + client, err = host.GetClient("non_existent") + assert.Error(t, err) + assert.Nil(t, client) +} + +func TestGetTools(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + tools := host.GetTools(ctx) + assert.Equal(t, 2, len(tools)) + assert.Contains(t, tools, "weather") + assert.Contains(t, tools, "filesystem") + + // Verify weather tools + weatherTools := tools["weather"] + assert.NoError(t, weatherTools.Err) + assert.NotEmpty(t, weatherTools.Tools) + + // Check if get_alerts tool exists + found := false + for _, tool := range weatherTools.Tools { + if tool.Name == "get_alerts" { + found = true + break + } + } + assert.True(t, found, "get_alerts tool not found in weather tools") +} + +func TestGetTool(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + // Test getting existing tool + tool, err := host.GetTool(ctx, "weather", "get_alerts") + require.NoError(t, err) + assert.NotNil(t, tool) + assert.Equal(t, "get_alerts", tool.Name) + + // Test getting non-existent tool + tool, err = host.GetTool(ctx, "weather", "non_existent") + assert.Error(t, err) + assert.Nil(t, tool) + + // Test getting tool from non-existent server + tool, err = host.GetTool(ctx, "non_existent", "get_alerts") + assert.Error(t, err) + assert.Nil(t, tool) +} + +func TestInvokeTool(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + // Test invoking existing tool + result, err := host.InvokeTool(ctx, "weather", "get_alerts", + map[string]interface{}{"state": "CA"}, + ) + require.NoError(t, err) + assert.NotNil(t, result) + + // Test invoking non-existent tool + result, err = host.InvokeTool(ctx, "weather", "non_existent", + map[string]interface{}{"state": "CA"}, + ) + require.Error(t, err, "expected error when tool does not exist") + assert.Nil(t, result) + + // Test invoking tool with invalid arguments + result, err = host.InvokeTool(ctx, "weather", "get_alerts", + map[string]interface{}{"invalid_arg": "value"}, + ) + require.Error(t, err, "expected error when arguments are invalid") + assert.Nil(t, result) +} + +func TestCloseServers(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + // Verify servers are connected + assert.Equal(t, 2, len(host.connections)) + + // Close servers + err = host.CloseServers() + require.NoError(t, err) + + // Verify connections are closed + assert.Empty(t, host.connections) +} + +func TestConcurrentOperations(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + // Test concurrent tool invocations + done := make(chan bool) + timeout := time.After(10 * time.Second) // Increase timeout to 10 seconds + + for i := 0; i < 5; i++ { + go func() { + result, err := host.InvokeTool(ctx, "weather", "get_alerts", + map[string]interface{}{"state": "CA"}, + ) + assert.NoError(t, err) + assert.NotNil(t, result) + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 5; i++ { + select { + case <-done: + // Success + case <-timeout: + t.Fatal("Timeout waiting for concurrent operations") + } + } +} + +func TestDisabledServer(t *testing.T) { + host, err := NewMCPHost("./testdata/test.mcp.json") + require.NoError(t, err) + + ctx := context.Background() + err = host.InitServers(ctx) + require.NoError(t, err) + + // Verify only enabled servers are connected + assert.Equal(t, 2, len(host.connections)) + assert.Contains(t, host.connections, "filesystem") + assert.Contains(t, host.connections, "weather") + assert.NotContains(t, host.connections, "disabled_server") + + // Test getting disabled server + client, err := host.GetClient("disabled_server") + assert.Error(t, err) + assert.Contains(t, err.Error(), "no connection found for server disabled_server") + assert.Nil(t, client) + + // Test getting tools from disabled server + tools := host.GetTools(ctx) + assert.Equal(t, 2, len(tools)) + assert.Contains(t, tools, "filesystem") + assert.Contains(t, tools, "weather") + assert.NotContains(t, tools, "disabled_server") + + // Test getting tool from disabled server + tool, err := host.GetTool(ctx, "disabled_server", "some_tool") + assert.Error(t, err) + assert.Contains(t, err.Error(), "no connection found for server disabled_server") + assert.Nil(t, tool) +} diff --git a/pkg/mcphost/testdata/test.mcp.json b/pkg/mcphost/testdata/test.mcp.json index b6a9a947..e80e4d58 100644 --- a/pkg/mcphost/testdata/test.mcp.json +++ b/pkg/mcphost/testdata/test.mcp.json @@ -22,6 +22,11 @@ "env": { "ABC": "123" } + }, + "disabled_server": { + "command": "echo", + "args": ["disabled"], + "disabled": true } } }