mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-21 16:23:16 +08:00
feat: add mcp host, load/invoke mcp tools
This commit is contained in:
298
pkg/mcphost/host.go
Normal file
298
pkg/mcphost/host.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/httprunner/httprunner/v5/internal/version"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MCPTools represents tools from a single MCP server
|
||||
type MCPTools struct {
|
||||
Name string
|
||||
Tools []mcp.Tool
|
||||
Err error
|
||||
}
|
||||
|
||||
// MCPHost manages MCP server connections and tools
|
||||
type MCPHost struct {
|
||||
mu sync.RWMutex
|
||||
connections map[string]*Connection
|
||||
config *MCPConfig
|
||||
}
|
||||
|
||||
// Connection represents a connection to an MCP server
|
||||
type Connection struct {
|
||||
Client client.MCPClient
|
||||
Config ServerConfig
|
||||
}
|
||||
|
||||
// NewMCPHost creates a new MCPHost instance
|
||||
func NewMCPHost(configPath string) (*MCPHost, error) {
|
||||
config, err := LoadMCPConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MCPHost{
|
||||
connections: make(map[string]*Connection),
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseHeaders parses header strings into a map
|
||||
func parseHeaders(headerList []string) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range headerList {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// startStdioLog starts a goroutine to print stdio logs
|
||||
func startStdioLog(stderr io.Reader, serverName string) {
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// prepareClientInitRequest creates a standard initialization request
|
||||
func prepareClientInitRequest() mcp.InitializeRequest {
|
||||
return mcp.InitializeRequest{
|
||||
Params: struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities mcp.ClientCapabilities `json:"capabilities"`
|
||||
ClientInfo mcp.Implementation `json:"clientInfo"`
|
||||
}{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
Capabilities: mcp.ClientCapabilities{},
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: "hrp-mcphost",
|
||||
Version: version.GetVersionInfo(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// InitServers initializes all MCP servers
|
||||
func (h *MCPHost) InitServers(ctx context.Context) error {
|
||||
for name, server := range h.config.MCPServers {
|
||||
if server.Config.IsDisabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := h.connectToServer(ctx, name, server.Config); err != nil {
|
||||
return fmt.Errorf("failed to connect to server %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClient returns the client for the specified server
|
||||
func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
conn, exists := h.connections[serverName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
}
|
||||
|
||||
return conn.Client, nil
|
||||
}
|
||||
|
||||
// connectToServer establishes connection to a single MCP server
|
||||
func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config ServerConfig) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
log.Debug().Str("server", serverName).Msg("connecting to MCP server")
|
||||
|
||||
// Close existing connection if any
|
||||
if existing, exists := h.connections[serverName]; exists {
|
||||
if err := existing.Client.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close existing connection: %w", err)
|
||||
}
|
||||
delete(h.connections, serverName)
|
||||
}
|
||||
|
||||
var mcpClient client.MCPClient
|
||||
var err error
|
||||
|
||||
// create client based on server type
|
||||
switch cfg := config.(type) {
|
||||
case SSEServerConfig:
|
||||
mcpClient, err = client.NewSSEMCPClient(cfg.Url, client.WithHeaders(parseHeaders(cfg.Headers)))
|
||||
case STDIOServerConfig:
|
||||
env := make([]string, 0, len(cfg.Env))
|
||||
for k, v := range cfg.Env {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
mcpClient, err = client.NewStdioMCPClient(cfg.Command, env, cfg.Args...)
|
||||
if stdioClient, ok := mcpClient.(*client.Client); ok {
|
||||
stderr, _ := client.GetStderr(stdioClient)
|
||||
startStdioLog(stderr, serverName)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported transport type: %s", config.GetType())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create client: %w", err)
|
||||
}
|
||||
|
||||
// initialize client
|
||||
_, err = mcpClient.Initialize(ctx, prepareClientInitRequest())
|
||||
if err != nil {
|
||||
mcpClient.Close()
|
||||
return errors.Wrapf(err, "initialize MCP client for %s failed", serverName)
|
||||
}
|
||||
|
||||
log.Info().Str("server", serverName).Msg("connected to MCP server")
|
||||
h.connections[serverName] = &Connection{
|
||||
Client: mcpClient,
|
||||
Config: config,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTools fetches available tools from all connected MCP servers
|
||||
func (h *MCPHost) GetTools(ctx context.Context) map[string]MCPTools {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
results := make(map[string]MCPTools)
|
||||
|
||||
for serverName, conn := range h.connections {
|
||||
if conn.Config.IsDisabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
results[serverName] = MCPTools{
|
||||
Name: serverName,
|
||||
Tools: nil,
|
||||
Err: fmt.Errorf("failed to get tools: %w", err),
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
results[serverName] = MCPTools{
|
||||
Name: serverName,
|
||||
Tools: listResults.Tools,
|
||||
Err: nil,
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// GetTool returns a specific tool from a server
|
||||
func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mcp.Tool, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
mcpTools, exists := h.GetTools(ctx)[serverName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
} else if mcpTools.Err != nil {
|
||||
return nil, mcpTools.Err
|
||||
}
|
||||
|
||||
for _, tool := range mcpTools.Tools {
|
||||
if tool.Name == toolName {
|
||||
return &tool, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tool %s not found", toolName)
|
||||
}
|
||||
|
||||
// handleToolError handles tool execution errors
|
||||
func handleToolError(result *mcp.CallToolResult) error {
|
||||
if !result.IsError {
|
||||
return nil
|
||||
}
|
||||
if len(result.Content) > 0 {
|
||||
return fmt.Errorf("tool error: %v", result.Content[0])
|
||||
}
|
||||
return fmt.Errorf("tool error: unknown error")
|
||||
}
|
||||
|
||||
// InvokeTool calls a tool with the given arguments
|
||||
func (h *MCPHost) InvokeTool(ctx context.Context,
|
||||
serverName, toolName string, arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
log.Info().Str("tool", toolName).Interface("args", arguments).
|
||||
Str("server", serverName).Msg("invoke tool")
|
||||
|
||||
conn, err := h.GetClient(serverName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err,
|
||||
"get mcp client for server %s failed", serverName)
|
||||
}
|
||||
|
||||
mcpTool, err := h.GetTool(ctx, serverName, toolName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err,
|
||||
"get mcp tool %s/%s failed", serverName, toolName)
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Meta *struct {
|
||||
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
}{
|
||||
Name: mcpTool.Name,
|
||||
Arguments: arguments,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := conn.CallTool(ctx, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err,
|
||||
"call tool %s/%s failed", serverName, toolName)
|
||||
}
|
||||
|
||||
if err := handleToolError(result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CloseServers closes all connected MCP servers
|
||||
func (h *MCPHost) CloseServers() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
log.Info().Msg("Shutting down MCP servers...")
|
||||
for name, conn := range h.connections {
|
||||
if err := conn.Client.Close(); err != nil {
|
||||
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
|
||||
} else {
|
||||
delete(h.connections, name)
|
||||
log.Info().Str("name", name).Msg("Server closed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
228
pkg/mcphost/host_test.go
Normal file
228
pkg/mcphost/host_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewMCPHost(t *testing.T) {
|
||||
// Test with valid config file
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, host)
|
||||
assert.NotNil(t, host.config)
|
||||
assert.NotEmpty(t, host.config.MCPServers)
|
||||
|
||||
// Test with non-existent config file
|
||||
host, err = NewMCPHost("./testdata/non_existent.json")
|
||||
require.Error(t, err, "expected error when config file does not exist")
|
||||
assert.Nil(t, host)
|
||||
}
|
||||
|
||||
func TestInitServers(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify connections are established
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
assert.Contains(t, host.connections, "filesystem")
|
||||
assert.Contains(t, host.connections, "weather")
|
||||
}
|
||||
|
||||
func TestGetClient(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test getting existing client
|
||||
client, err := host.GetClient("weather")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// Test getting non-existent client
|
||||
client, err = host.GetClient("non_existent")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, client)
|
||||
}
|
||||
|
||||
func TestGetTools(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
assert.Contains(t, tools, "weather")
|
||||
assert.Contains(t, tools, "filesystem")
|
||||
|
||||
// Verify weather tools
|
||||
weatherTools := tools["weather"]
|
||||
assert.NoError(t, weatherTools.Err)
|
||||
assert.NotEmpty(t, weatherTools.Tools)
|
||||
|
||||
// Check if get_alerts tool exists
|
||||
found := false
|
||||
for _, tool := range weatherTools.Tools {
|
||||
if tool.Name == "get_alerts" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "get_alerts tool not found in weather tools")
|
||||
}
|
||||
|
||||
func TestGetTool(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test getting existing tool
|
||||
tool, err := host.GetTool(ctx, "weather", "get_alerts")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, tool)
|
||||
assert.Equal(t, "get_alerts", tool.Name)
|
||||
|
||||
// Test getting non-existent tool
|
||||
tool, err = host.GetTool(ctx, "weather", "non_existent")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tool)
|
||||
|
||||
// Test getting tool from non-existent server
|
||||
tool, err = host.GetTool(ctx, "non_existent", "get_alerts")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tool)
|
||||
}
|
||||
|
||||
func TestInvokeTool(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test invoking existing tool
|
||||
result, err := host.InvokeTool(ctx, "weather", "get_alerts",
|
||||
map[string]interface{}{"state": "CA"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
|
||||
// Test invoking non-existent tool
|
||||
result, err = host.InvokeTool(ctx, "weather", "non_existent",
|
||||
map[string]interface{}{"state": "CA"},
|
||||
)
|
||||
require.Error(t, err, "expected error when tool does not exist")
|
||||
assert.Nil(t, result)
|
||||
|
||||
// Test invoking tool with invalid arguments
|
||||
result, err = host.InvokeTool(ctx, "weather", "get_alerts",
|
||||
map[string]interface{}{"invalid_arg": "value"},
|
||||
)
|
||||
require.Error(t, err, "expected error when arguments are invalid")
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCloseServers(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify servers are connected
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
|
||||
// Close servers
|
||||
err = host.CloseServers()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify connections are closed
|
||||
assert.Empty(t, host.connections)
|
||||
}
|
||||
|
||||
func TestConcurrentOperations(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent tool invocations
|
||||
done := make(chan bool)
|
||||
timeout := time.After(10 * time.Second) // Increase timeout to 10 seconds
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
result, err := host.InvokeTool(ctx, "weather", "get_alerts",
|
||||
map[string]interface{}{"state": "CA"},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 5; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for concurrent operations")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabledServer(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only enabled servers are connected
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
assert.Contains(t, host.connections, "filesystem")
|
||||
assert.Contains(t, host.connections, "weather")
|
||||
assert.NotContains(t, host.connections, "disabled_server")
|
||||
|
||||
// Test getting disabled server
|
||||
client, err := host.GetClient("disabled_server")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no connection found for server disabled_server")
|
||||
assert.Nil(t, client)
|
||||
|
||||
// Test getting tools from disabled server
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
assert.Contains(t, tools, "filesystem")
|
||||
assert.Contains(t, tools, "weather")
|
||||
assert.NotContains(t, tools, "disabled_server")
|
||||
|
||||
// Test getting tool from disabled server
|
||||
tool, err := host.GetTool(ctx, "disabled_server", "some_tool")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no connection found for server disabled_server")
|
||||
assert.Nil(t, tool)
|
||||
}
|
||||
5
pkg/mcphost/testdata/test.mcp.json
vendored
5
pkg/mcphost/testdata/test.mcp.json
vendored
@@ -22,6 +22,11 @@
|
||||
"env": {
|
||||
"ABC": "123"
|
||||
}
|
||||
},
|
||||
"disabled_server": {
|
||||
"command": "echo",
|
||||
"args": ["disabled"],
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user