mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-11 18:11:21 +08:00
refactor: move mcphost package to top level
This commit is contained in:
@@ -1,5 +0,0 @@
|
||||
# mcphost
|
||||
|
||||
This package is a fork of [mark3labs/mcphost], and it helps HttpRunner to interact with external tools through the Model Context Protocol (MCP).
|
||||
|
||||
[mark3labs/mcphost]: https://github.com/mark3labs/mcphost
|
||||
@@ -1,416 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/glamour/styles"
|
||||
"github.com/charmbracelet/huh"
|
||||
"github.com/charmbracelet/huh/spinner"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/lipgloss/list"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/code"
|
||||
"github.com/httprunner/httprunner/v5/uixt/ai"
|
||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// NewChat creates a new chat session
|
||||
func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat, error) {
|
||||
// Get model config from environment variables
|
||||
modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeGPT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
|
||||
}
|
||||
|
||||
// Create markdown renderer
|
||||
renderer, err := glamour.NewTermRenderer(
|
||||
glamour.WithStandardStyle(styles.TokyoNightStyle),
|
||||
glamour.WithWordWrap(getTerminalWidth()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create markdown renderer")
|
||||
}
|
||||
|
||||
// Load system prompt from file if provided
|
||||
systemPrompt := "chat to interact with MCP tools"
|
||||
if systemPromptFile != "" {
|
||||
customPrompt, err := loadSystemPrompt(systemPromptFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load system prompt")
|
||||
}
|
||||
if customPrompt != "" {
|
||||
systemPrompt = customPrompt
|
||||
}
|
||||
}
|
||||
|
||||
// convert MCP tools to eino tool infos
|
||||
einoTools, err := h.GetEinoToolInfos(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get eino tool infos")
|
||||
}
|
||||
|
||||
toolCallingModel, err := model.WithTools(einoTools)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
|
||||
}
|
||||
|
||||
return &Chat{
|
||||
model: toolCallingModel,
|
||||
systemPrompt: systemPrompt,
|
||||
history: ai.ConversationHistory{},
|
||||
renderer: renderer,
|
||||
host: h,
|
||||
tools: einoTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat represents a chat session with LLM
|
||||
type Chat struct {
|
||||
model model.ToolCallingChatModel
|
||||
systemPrompt string
|
||||
history ai.ConversationHistory
|
||||
renderer *glamour.TermRenderer
|
||||
host *MCPHost
|
||||
tools []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// Start starts the chat session
|
||||
func (c *Chat) Start() error {
|
||||
// Add system message
|
||||
c.history = ai.ConversationHistory{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: c.systemPrompt,
|
||||
},
|
||||
}
|
||||
|
||||
c.showWelcome()
|
||||
|
||||
for {
|
||||
var input string
|
||||
err := huh.NewForm(huh.NewGroup(huh.NewText().
|
||||
Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)").
|
||||
Value(&input).
|
||||
CharLimit(5000)),
|
||||
).WithWidth(getTerminalWidth()).
|
||||
WithTheme(huh.ThemeCharm()).
|
||||
Run()
|
||||
if err != nil {
|
||||
// Check if it's a user abort (Ctrl+C)
|
||||
if errors.Is(err, huh.ErrUserAborted) {
|
||||
fmt.Println("\nGoodbye!")
|
||||
return nil // Exit cleanly
|
||||
}
|
||||
return err // Return other errors normally
|
||||
}
|
||||
|
||||
if input == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle commands
|
||||
if strings.HasPrefix(input, "/") {
|
||||
if err := c.handleCommand(input); err != nil {
|
||||
log.Error().Err(err).Msg("failed to handle command")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// run prompt with MCP tools
|
||||
if err := c.runPrompt(input); err != nil {
|
||||
log.Error().Err(err).Msg("run prompt error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runPrompt run prompt with MCP tools
|
||||
func (c *Chat) runPrompt(prompt string) error {
|
||||
fmt.Printf("\n%s\n", promptStyle.Render("You: "+prompt))
|
||||
|
||||
// Create user message
|
||||
userMsg := &schema.Message{
|
||||
Role: schema.User,
|
||||
Content: prompt,
|
||||
}
|
||||
c.history = append(c.history, userMsg)
|
||||
|
||||
// Call LLM model to get response
|
||||
ctx := context.Background()
|
||||
var message *schema.Message
|
||||
var modelErr error
|
||||
_ = spinner.New().Title("Thinking...").Action(func() {
|
||||
message, modelErr = c.model.Generate(ctx, c.history)
|
||||
}).Run()
|
||||
if modelErr != nil {
|
||||
return modelErr
|
||||
}
|
||||
|
||||
// Log usage statistics
|
||||
if usage := message.ResponseMeta.Usage; usage != nil {
|
||||
log.Debug().Int("input_tokens", usage.PromptTokens).
|
||||
Int("output_tokens", usage.CompletionTokens).
|
||||
Int("total_tokens", usage.TotalTokens).Msg("Usage statistics")
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
toolCalls := message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
return c.handleToolCalls(ctx, toolCalls)
|
||||
}
|
||||
|
||||
// Add assistant's response to history
|
||||
toolMsg := &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: message.Content,
|
||||
}
|
||||
c.history = append(c.history, toolMsg)
|
||||
c.renderContent("Assistant", message.Content)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall) error {
|
||||
for _, toolCall := range toolCalls {
|
||||
serverToolName := toolCall.Function.Name
|
||||
toolArgs := toolCall.Function.Arguments
|
||||
log.Debug().Str("name", serverToolName).Str("args", toolArgs).Msg("handle tool call")
|
||||
|
||||
// Parse tool name
|
||||
parts := strings.SplitN(serverToolName, "__", 2)
|
||||
if len(parts) != 2 {
|
||||
log.Error().Str("name", serverToolName).Msg("invalid tool name")
|
||||
continue
|
||||
}
|
||||
serverName, toolName := parts[0], parts[1]
|
||||
|
||||
// Unmarshal tool arguments from JSON string
|
||||
var argsMap map[string]interface{}
|
||||
if err := sonic.UnmarshalString(toolArgs, &argsMap); err != nil {
|
||||
log.Error().Err(err).Str("args", toolArgs).Msg("failed to unmarshal tool arguments")
|
||||
continue
|
||||
}
|
||||
|
||||
// Invoke tool
|
||||
result, err := c.host.InvokeTool(ctx, serverName, toolName, argsMap)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("invoke tool failed")
|
||||
continue
|
||||
}
|
||||
|
||||
// Format tool result
|
||||
resultStr := ""
|
||||
if result != nil && len(result.Content) > 0 {
|
||||
for _, item := range result.Content {
|
||||
resultStr += fmt.Sprintf("%v\n", item)
|
||||
}
|
||||
} else {
|
||||
resultStr = fmt.Sprintf("%+v", result)
|
||||
}
|
||||
c.renderContent("Tool result", resultStr)
|
||||
|
||||
// Add tool result to history
|
||||
toolMsg := &schema.Message{
|
||||
Role: schema.Tool,
|
||||
Content: resultStr,
|
||||
ToolCallID: toolCall.ID,
|
||||
}
|
||||
c.history = append(c.history, toolMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCommand handles commands
|
||||
func (c *Chat) handleCommand(cmd string) error {
|
||||
switch cmd {
|
||||
case "/help":
|
||||
c.showWelcome()
|
||||
case "/tools":
|
||||
c.showTools()
|
||||
case "/history":
|
||||
c.showHistory()
|
||||
case "/clear":
|
||||
c.clearHistory()
|
||||
case "/quit":
|
||||
fmt.Println("Goodbye!")
|
||||
os.Exit(0)
|
||||
default:
|
||||
fmt.Printf("Unknown command: %s\n", cmd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// showWelcome show welcome and help information
|
||||
func (c *Chat) showWelcome() {
|
||||
markdown := fmt.Sprintf(`# Welcome to HttpRunner MCPHost Chat!
|
||||
|
||||
## Available Commands
|
||||
|
||||
The following commands are available:
|
||||
|
||||
- **/help**: Show this help message
|
||||
- **/tools**: List all available tools
|
||||
- **/history**: Display conversation history
|
||||
- **/clear**: Clear conversation history
|
||||
- **/quit**: Exit the chat session
|
||||
|
||||
You can also press Ctrl+C at any time to quit.
|
||||
|
||||
## Configurations
|
||||
|
||||
- **system-prompt**: %s
|
||||
- **mcp-config**: %s
|
||||
`, c.systemPrompt, c.host.config.ConfigPath)
|
||||
|
||||
c.renderContent("", markdown)
|
||||
}
|
||||
|
||||
func (c *Chat) showHistory() {
|
||||
if len(c.history) <= 1 { // Only system message
|
||||
fmt.Println("No conversation history yet.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("\nConversation History:")
|
||||
for _, msg := range c.history {
|
||||
if msg.Role == schema.System {
|
||||
continue
|
||||
}
|
||||
|
||||
role := "You"
|
||||
if msg.Role == schema.Assistant {
|
||||
role = "Assistant"
|
||||
}
|
||||
c.renderContent(role, msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Chat) clearHistory() {
|
||||
// Keep only the system message
|
||||
systemMsg := c.history[0]
|
||||
c.history = ai.ConversationHistory{systemMsg}
|
||||
fmt.Println("Conversation history cleared.")
|
||||
}
|
||||
|
||||
func (c *Chat) showTools() {
|
||||
if c.host == nil {
|
||||
fmt.Println("No MCP host loaded.")
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
results := c.host.GetTools(ctx)
|
||||
if len(results) == 0 {
|
||||
fmt.Println("No MCP servers loaded.")
|
||||
return
|
||||
}
|
||||
width := getTerminalWidth()
|
||||
contentWidth := width - 12
|
||||
l := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoPurple).MarginRight(1))
|
||||
for _, serverTools := range results {
|
||||
serverList := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoCyan).MarginRight(1))
|
||||
if serverTools.Err != nil {
|
||||
serverList.Item(contentStyle.Render(fmt.Sprintf("Error: %v", serverTools.Err)))
|
||||
} else if len(serverTools.Tools) == 0 {
|
||||
serverList.Item(contentStyle.Render("No tools available."))
|
||||
} else {
|
||||
for _, tool := range serverTools.Tools {
|
||||
descStyle := lipgloss.NewStyle().Foreground(tokyoFg).Width(contentWidth).Align(lipgloss.Left)
|
||||
toolDesc := list.New().EnumeratorStyle(
|
||||
lipgloss.NewStyle().Foreground(tokyoGreen).MarginRight(1),
|
||||
).Item(descStyle.Render(tool.Description))
|
||||
serverList.Item(toolNameStyle.Render(tool.Name)).Item(toolDesc)
|
||||
}
|
||||
}
|
||||
l.Item(serverTools.ServerName).Item(serverList)
|
||||
}
|
||||
containerStyle := lipgloss.NewStyle().Margin(2).Width(width)
|
||||
fmt.Print("\n" + containerStyle.Render(l.String()) + "\n")
|
||||
}
|
||||
|
||||
// Render and display content
|
||||
func (c *Chat) renderContent(title, content string) {
|
||||
output, err := c.renderer.Render(content)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("render content failed")
|
||||
output = content
|
||||
}
|
||||
if title != "" {
|
||||
title = title + ": "
|
||||
}
|
||||
fmt.Printf("\n%s", responseStyle.Render(title+output))
|
||||
}
|
||||
|
||||
// loadSystemPrompt loads the system prompt from a JSON file
|
||||
func loadSystemPrompt(filePath string) (string, error) {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading prompt file: %v", err)
|
||||
}
|
||||
|
||||
// Read file content directly as prompt
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func getTerminalWidth() int {
|
||||
width, _, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil {
|
||||
return 80 // Fallback width
|
||||
}
|
||||
return width - 20
|
||||
}
|
||||
|
||||
var (
|
||||
// Tokyo Night theme colors
|
||||
tokyoPurple = lipgloss.Color("99") // #9d7cd8
|
||||
tokyoCyan = lipgloss.Color("73") // #7dcfff
|
||||
tokyoBlue = lipgloss.Color("111") // #7aa2f7
|
||||
tokyoGreen = lipgloss.Color("120") // #73daca
|
||||
tokyoRed = lipgloss.Color("203") // #f7768e
|
||||
tokyoOrange = lipgloss.Color("215") // #ff9e64
|
||||
tokyoFg = lipgloss.Color("189") // #c0caf5
|
||||
tokyoGray = lipgloss.Color("237") // #3b4261
|
||||
tokyoBg = lipgloss.Color("234") // #1a1b26
|
||||
|
||||
promptStyle = lipgloss.NewStyle().
|
||||
Foreground(tokyoBlue).
|
||||
PaddingLeft(2)
|
||||
|
||||
responseStyle = lipgloss.NewStyle().
|
||||
Foreground(tokyoFg).
|
||||
PaddingLeft(2)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(tokyoRed).
|
||||
Bold(true)
|
||||
|
||||
toolNameStyle = lipgloss.NewStyle().
|
||||
Foreground(tokyoCyan).
|
||||
Bold(true)
|
||||
|
||||
descriptionStyle = lipgloss.NewStyle().
|
||||
Foreground(tokyoFg).
|
||||
PaddingBottom(1)
|
||||
|
||||
contentStyle = lipgloss.NewStyle().
|
||||
Background(tokyoBg).
|
||||
PaddingLeft(4).
|
||||
PaddingRight(4)
|
||||
)
|
||||
@@ -1,50 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewChat(t *testing.T) {
|
||||
systemPromptFile := "test_system_prompt.txt"
|
||||
_ = os.WriteFile(systemPromptFile, []byte("You are a helpful assistant."), 0o644)
|
||||
defer os.Remove(systemPromptFile)
|
||||
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), systemPromptFile)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, chat)
|
||||
assert.NotEmpty(t, chat.systemPrompt)
|
||||
assert.NotNil(t, chat.tools)
|
||||
}
|
||||
|
||||
func TestRunPromptWithNoToolCall(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = chat.runPrompt("hi")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.history) > 1)
|
||||
}
|
||||
|
||||
func TestRunPromptWithToolCall(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.tools) > 0)
|
||||
|
||||
err = chat.runPrompt("what is the weather in CA")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.history) > 1)
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const (
|
||||
transportStdio = "stdio"
|
||||
transportSSE = "sse"
|
||||
)
|
||||
|
||||
// MCPConfig represents the configuration for MCP servers
|
||||
type MCPConfig struct {
|
||||
ConfigPath string `json:"-"`
|
||||
MCPServers map[string]ServerConfigWrapper `json:"mcpServers"`
|
||||
}
|
||||
|
||||
// ServerConfig is an interface for different types of server configurations
|
||||
type ServerConfig interface {
|
||||
GetType() string
|
||||
IsDisabled() bool
|
||||
}
|
||||
|
||||
// STDIOServerConfig represents configuration for a STDIO-based server
|
||||
type STDIOServerConfig struct {
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
func (s STDIOServerConfig) GetType() string {
|
||||
return transportStdio
|
||||
}
|
||||
|
||||
func (s STDIOServerConfig) IsDisabled() bool {
|
||||
return s.Disabled
|
||||
}
|
||||
|
||||
// SSEServerConfig represents configuration for an SSE-based server
|
||||
type SSEServerConfig struct {
|
||||
Url string `json:"url"`
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
func (s SSEServerConfig) GetType() string {
|
||||
return transportSSE
|
||||
}
|
||||
|
||||
func (s SSEServerConfig) IsDisabled() bool {
|
||||
return s.Disabled
|
||||
}
|
||||
|
||||
// ServerConfigWrapper is a wrapper for different types of server configurations
|
||||
type ServerConfigWrapper struct {
|
||||
Config ServerConfig
|
||||
}
|
||||
|
||||
func (w *ServerConfigWrapper) UnmarshalJSON(data []byte) error {
|
||||
var typeField struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &typeField); err != nil {
|
||||
return err
|
||||
}
|
||||
if typeField.Url != "" {
|
||||
// If the URL field is present, treat it as an SSE server
|
||||
var sse SSEServerConfig
|
||||
if err := json.Unmarshal(data, &sse); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Config = sse
|
||||
} else {
|
||||
// Otherwise, treat it as a STDIOServerConfig
|
||||
var stdio STDIOServerConfig
|
||||
if err := json.Unmarshal(data, &stdio); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Config = stdio
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w ServerConfigWrapper) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(w.Config)
|
||||
}
|
||||
|
||||
// LoadMCPConfig loads the MCP configuration from the specified path or default location
|
||||
func LoadMCPConfig(configPath string) (*MCPConfig, error) {
|
||||
if configPath == "" {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting home directory: %w", err)
|
||||
}
|
||||
configPath = filepath.Join(homeDir, ".mcp.json")
|
||||
}
|
||||
configPath = os.ExpandEnv(configPath)
|
||||
|
||||
// Check if config file exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("config file does not exist: %s", configPath)
|
||||
}
|
||||
|
||||
// Read existing config
|
||||
configData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error reading config file %s: %w",
|
||||
configPath,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
var config MCPConfig
|
||||
if err := json.Unmarshal(configData, &config); err != nil {
|
||||
return nil, fmt.Errorf("error parsing config file: %w", err)
|
||||
}
|
||||
config.ConfigPath = configPath
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLoadSettings(t *testing.T) {
|
||||
// Load settings from test.mcp.json
|
||||
settings, err := LoadMCPConfig("testdata/test.mcp.json")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load settings: %v", err)
|
||||
}
|
||||
|
||||
// Verify settings are loaded correctly
|
||||
assert.NotNil(t, settings)
|
||||
assert.Contains(t, settings.MCPServers, "filesystem")
|
||||
assert.Contains(t, settings.MCPServers, "weather")
|
||||
|
||||
// Verify specific server configurations
|
||||
filesystemConfig := settings.MCPServers["filesystem"].Config.(STDIOServerConfig)
|
||||
assert.Equal(t, "npx", filesystemConfig.Command)
|
||||
assert.Equal(t, []string{"-y", "@modelcontextprotocol/server-filesystem", "./"}, filesystemConfig.Args)
|
||||
|
||||
weatherConfig := settings.MCPServers["weather"].Config.(STDIOServerConfig)
|
||||
assert.Equal(t, "uv", weatherConfig.Command)
|
||||
assert.Equal(t, []string{"--directory", "/Users/debugtalk/MyProjects/HttpRunner-dev/httprunner/pkg/mcphost/testdata", "run", "demo_weather.py"}, weatherConfig.Args)
|
||||
assert.Equal(t, map[string]string{"ABC": "123"}, weatherConfig.Env)
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MCPToolRecord represents a single tool record in the database
|
||||
// Each record contains detailed information about a tool and its server
|
||||
type MCPToolRecord struct {
|
||||
ToolID string `json:"tool_id"` // Unique identifier for the tool record
|
||||
ServerName string `json:"mcp_server"` // Name of the MCP server
|
||||
ToolName string `json:"tool_name"` // Name of the tool
|
||||
Description string `json:"description"` // Tool description
|
||||
Parameters string `json:"parameters"` // Tool input parameters in JSON format
|
||||
Returns string `json:"returns"` // Tool return value format in JSON format
|
||||
CreatedAt time.Time `json:"created_at"` // Record creation time
|
||||
LastUpdatedAt time.Time `json:"last_updated_at"` // Record last update time
|
||||
}
|
||||
|
||||
// DocStringInfo contains the parsed information from a Python docstring
|
||||
type DocStringInfo struct {
|
||||
Description string
|
||||
Parameters map[string]string
|
||||
Returns map[string]string
|
||||
}
|
||||
|
||||
// extractDocStringInfo extracts information from a Python docstring
|
||||
// Example input:
|
||||
// """Get weather alerts for a US state.
|
||||
//
|
||||
// Args:
|
||||
// state: Two-letter US state code (e.g. CA, NY)
|
||||
//
|
||||
// Returns:
|
||||
// alerts: List of active weather alerts for the specified state
|
||||
// error: Error message if the request fails
|
||||
// """
|
||||
func extractDocStringInfo(docstring string) DocStringInfo {
|
||||
info := DocStringInfo{
|
||||
Parameters: make(map[string]string),
|
||||
Returns: make(map[string]string),
|
||||
}
|
||||
|
||||
// Find the Args and Returns sections
|
||||
argsIndex := strings.Index(docstring, "Args:")
|
||||
returnsIndex := strings.Index(docstring, "Returns:")
|
||||
|
||||
// Extract description (everything before Args)
|
||||
if argsIndex != -1 {
|
||||
info.Description = strings.TrimSpace(docstring[:argsIndex])
|
||||
} else if returnsIndex != -1 {
|
||||
info.Description = strings.TrimSpace(docstring[:returnsIndex])
|
||||
} else {
|
||||
info.Description = strings.TrimSpace(docstring)
|
||||
return info
|
||||
}
|
||||
|
||||
// Helper function to extract key-value pairs from a section
|
||||
extractSection := func(content string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
lines := strings.Split(content, "\n")
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
if key != "" && value != "" {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Extract Args section
|
||||
if argsIndex != -1 {
|
||||
endIndex := returnsIndex
|
||||
if endIndex == -1 {
|
||||
endIndex = len(docstring)
|
||||
}
|
||||
argsContent := docstring[argsIndex+len("Args:") : endIndex]
|
||||
info.Parameters = extractSection(argsContent)
|
||||
}
|
||||
|
||||
// Extract Returns section
|
||||
if returnsIndex != -1 {
|
||||
returnsContent := docstring[returnsIndex+len("Returns:"):]
|
||||
info.Returns = extractSection(returnsContent)
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// ConvertToolsToRecords converts []MCPTools to a list of database records
|
||||
func ConvertToolsToRecords(tools []MCPTools) []MCPToolRecord {
|
||||
var records []MCPToolRecord
|
||||
now := time.Now()
|
||||
|
||||
for _, mcpTools := range tools {
|
||||
if mcpTools.Err != nil {
|
||||
log.Error().Str("server", mcpTools.ServerName).Err(mcpTools.Err).Msg("skip tools conversion due to error")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tool := range mcpTools.Tools {
|
||||
// Generate unique ID by combining server name and tool name
|
||||
id := fmt.Sprintf("%s_%s", mcpTools.ServerName, tool.Name)
|
||||
|
||||
// Extract docstring information
|
||||
info := extractDocStringInfo(tool.Description)
|
||||
|
||||
// Convert parameters and returns to JSON
|
||||
paramsJSON, err := sonic.MarshalString(info.Parameters)
|
||||
if err != nil {
|
||||
log.Warn().Interface("params", info.Parameters).Err(err).Msg("failed to marshal parameters to JSON")
|
||||
paramsJSON = "{}"
|
||||
}
|
||||
|
||||
returnsJSON, err := sonic.MarshalString(info.Returns)
|
||||
if err != nil {
|
||||
log.Warn().Interface("returns", info.Returns).Err(err).Msg("failed to marshal returns to JSON")
|
||||
returnsJSON = "{}"
|
||||
}
|
||||
|
||||
record := MCPToolRecord{
|
||||
ToolID: id,
|
||||
ServerName: mcpTools.ServerName,
|
||||
ToolName: tool.Name,
|
||||
Description: info.Description,
|
||||
Parameters: paramsJSON,
|
||||
Returns: returnsJSON,
|
||||
CreatedAt: now,
|
||||
LastUpdatedAt: now,
|
||||
}
|
||||
|
||||
records = append(records, record)
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// ExportToolsToJSON dumps MCP tools to JSON file
|
||||
func (h *MCPHost) ExportToolsToJSON(ctx context.Context, dumpPath string) error {
|
||||
// get all tools
|
||||
tools := h.GetTools(ctx)
|
||||
// convert to records
|
||||
records := ConvertToolsToRecords(tools)
|
||||
// convert to JSON
|
||||
recordsJSON, err := sonic.MarshalIndent(records, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal records to JSON: %w", err)
|
||||
}
|
||||
// create output directory
|
||||
outputDir := filepath.Dir(dumpPath)
|
||||
if outputDir != "." {
|
||||
if err := os.MkdirAll(outputDir, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
}
|
||||
// write to file
|
||||
if err := os.WriteFile(dumpPath, []byte(recordsJSON), 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write records to file: %w", err)
|
||||
}
|
||||
log.Info().Str("path", dumpPath).Msg("Tools records exported successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -1,237 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConvertToolsToRecordsFromFile(t *testing.T) {
|
||||
hub, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// use ExportToolsToJSON to dump tools to JSON file
|
||||
err = hub.ExportToolsToJSON(context.Background(), "./tools_records.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// read the exported JSON file
|
||||
data, err := os.ReadFile("./tools_records.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// parse the exported JSON data
|
||||
var records []MCPToolRecord
|
||||
err = json.Unmarshal(data, &records)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify the number of records
|
||||
assert.NotEmpty(t, records, "Exported records should not be empty")
|
||||
|
||||
t.Logf("Tools records written to ./tools_records.json")
|
||||
}
|
||||
|
||||
func TestExtractDocStringInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docstring string
|
||||
want DocStringInfo
|
||||
}{
|
||||
{
|
||||
name: "complete docstring with args and returns",
|
||||
docstring: `Get weather alerts for a US state.
|
||||
|
||||
Args:
|
||||
state: Two-letter US state code (e.g. CA, NY)
|
||||
|
||||
Returns:
|
||||
alerts: List of active weather alerts for the specified state
|
||||
error: Error message if the request fails
|
||||
`,
|
||||
want: DocStringInfo{
|
||||
Description: "Get weather alerts for a US state.",
|
||||
Parameters: map[string]string{
|
||||
"state": "Two-letter US state code (e.g. CA, NY)",
|
||||
},
|
||||
Returns: map[string]string{
|
||||
"alerts": "List of active weather alerts for the specified state",
|
||||
"error": "Error message if the request fails",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "docstring with only args",
|
||||
docstring: `Do screen swipe action.
|
||||
|
||||
Args:
|
||||
direction: swipe direction (up, down)
|
||||
`,
|
||||
want: DocStringInfo{
|
||||
Description: "Do screen swipe action.",
|
||||
Parameters: map[string]string{
|
||||
"direction": "swipe direction (up, down)",
|
||||
},
|
||||
Returns: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "docstring with only description",
|
||||
docstring: "Simple tool with no parameters.",
|
||||
want: DocStringInfo{
|
||||
Description: "Simple tool with no parameters.",
|
||||
Parameters: map[string]string{},
|
||||
Returns: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "docstring with multiple parameters",
|
||||
docstring: `Perform complex operation.
|
||||
|
||||
Args:
|
||||
param1: first parameter description
|
||||
param2: second parameter description
|
||||
param3: third parameter description
|
||||
|
||||
Returns:
|
||||
result: operation result
|
||||
`,
|
||||
want: DocStringInfo{
|
||||
Description: "Perform complex operation.",
|
||||
Parameters: map[string]string{
|
||||
"param1": "first parameter description",
|
||||
"param2": "second parameter description",
|
||||
"param3": "third parameter description",
|
||||
},
|
||||
Returns: map[string]string{
|
||||
"result": "operation result",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractDocStringInfo(tt.docstring)
|
||||
assert.Equal(t, tt.want.Description, got.Description)
|
||||
assert.Equal(t, tt.want.Parameters, got.Parameters)
|
||||
assert.Equal(t, tt.want.Returns, got.Returns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToolsToRecords(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []MCPTools
|
||||
want []MCPToolRecord
|
||||
}{
|
||||
{
|
||||
name: "convert weather tool",
|
||||
tools: []MCPTools{
|
||||
{
|
||||
ServerName: "weather",
|
||||
Tools: []mcp.Tool{
|
||||
{
|
||||
Name: "get_alerts",
|
||||
Description: `Get weather alerts for a US state.
|
||||
|
||||
Args:
|
||||
state: Two-letter US state code (e.g. CA, NY)
|
||||
|
||||
Returns:
|
||||
alerts: List of active weather alerts for the specified state
|
||||
error: Error message if the request fails
|
||||
`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []MCPToolRecord{
|
||||
{
|
||||
ToolID: "weather_get_alerts",
|
||||
ServerName: "weather",
|
||||
ToolName: "get_alerts",
|
||||
Description: "Get weather alerts for a US state.",
|
||||
Parameters: `{"state":"Two-letter US state code (e.g. CA, NY)"}`,
|
||||
Returns: `{"alerts":"List of active weather alerts for the specified state","error":"Error message if the request fails"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "convert multiple tools",
|
||||
tools: []MCPTools{
|
||||
{
|
||||
ServerName: "ui",
|
||||
Tools: []mcp.Tool{
|
||||
{
|
||||
Name: "swipe",
|
||||
Description: `Do screen swipe action.
|
||||
|
||||
Args:
|
||||
direction: swipe direction (up, down)
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "tap",
|
||||
Description: "Tap on screen at specified position.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []MCPToolRecord{
|
||||
{
|
||||
ToolID: "ui_swipe",
|
||||
ServerName: "ui",
|
||||
ToolName: "swipe",
|
||||
Description: "Do screen swipe action.",
|
||||
Parameters: `{"direction":"swipe direction (up, down)"}`,
|
||||
Returns: "{}",
|
||||
},
|
||||
{
|
||||
ToolID: "ui_tap",
|
||||
ServerName: "ui",
|
||||
ToolName: "tap",
|
||||
Description: "Tap on screen at specified position.",
|
||||
Parameters: "{}",
|
||||
Returns: "{}",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ConvertToolsToRecords(tt.tools)
|
||||
|
||||
// Compare each record
|
||||
require.Equal(t, len(tt.want), len(got))
|
||||
for i := range tt.want {
|
||||
assert.Equal(t, tt.want[i].ToolID, got[i].ToolID)
|
||||
assert.Equal(t, tt.want[i].ServerName, got[i].ServerName)
|
||||
assert.Equal(t, tt.want[i].ToolName, got[i].ToolName)
|
||||
assert.Equal(t, tt.want[i].Description, got[i].Description)
|
||||
|
||||
// Compare JSON content (ignoring whitespace differences)
|
||||
var wantParams, gotParams, wantReturns, gotReturns map[string]string
|
||||
require.NoError(t, json.Unmarshal([]byte(tt.want[i].Parameters), &wantParams))
|
||||
require.NoError(t, json.Unmarshal([]byte(got[i].Parameters), &gotParams))
|
||||
require.NoError(t, json.Unmarshal([]byte(tt.want[i].Returns), &wantReturns))
|
||||
require.NoError(t, json.Unmarshal([]byte(got[i].Returns), &gotReturns))
|
||||
|
||||
assert.Equal(t, wantParams, gotParams)
|
||||
assert.Equal(t, wantReturns, gotReturns)
|
||||
|
||||
// Verify timestamps are recent (within last 5 seconds)
|
||||
now := time.Now()
|
||||
assert.True(t, now.Sub(got[i].CreatedAt) < 5*time.Second, "CreatedAt should be recent")
|
||||
assert.True(t, now.Sub(got[i].LastUpdatedAt) < 5*time.Second, "LastUpdatedAt should be recent")
|
||||
// CreatedAt and LastUpdatedAt should be the same
|
||||
assert.Equal(t, got[i].CreatedAt, got[i].LastUpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,375 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/internal/version"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MCPHost manages MCP server connections and tools
|
||||
type MCPHost struct {
|
||||
mu sync.RWMutex
|
||||
connections map[string]*Connection
|
||||
config *MCPConfig
|
||||
}
|
||||
|
||||
// Connection represents a connection to an MCP server
|
||||
type Connection struct {
|
||||
Client client.MCPClient
|
||||
Config ServerConfig
|
||||
}
|
||||
|
||||
// MCPTools represents tools from a single MCP server
|
||||
type MCPTools struct {
|
||||
ServerName string
|
||||
Tools []mcp.Tool
|
||||
Err error
|
||||
}
|
||||
|
||||
// NewMCPHost creates a new MCPHost instance
|
||||
func NewMCPHost(configPath string) (*MCPHost, error) {
|
||||
config, err := LoadMCPConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host := &MCPHost{
|
||||
connections: make(map[string]*Connection),
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Initialize MCP servers
|
||||
if err := host.InitServers(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize MCP servers: %w", err)
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// InitServers initializes all MCP servers
|
||||
func (h *MCPHost) InitServers(ctx context.Context) error {
|
||||
for name, server := range h.config.MCPServers {
|
||||
if server.Config.IsDisabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := h.connectToServer(ctx, name, server.Config); err != nil {
|
||||
return fmt.Errorf("failed to connect to server %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectToServer establishes connection to a single MCP server
|
||||
func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config ServerConfig) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
log.Debug().Str("server", serverName).Msg("connecting to MCP server")
|
||||
|
||||
// Close existing connection if any
|
||||
if existing, exists := h.connections[serverName]; exists {
|
||||
if err := existing.Client.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close existing connection: %w", err)
|
||||
}
|
||||
delete(h.connections, serverName)
|
||||
}
|
||||
|
||||
var mcpClient client.MCPClient
|
||||
var err error
|
||||
|
||||
// create client based on server type
|
||||
switch cfg := config.(type) {
|
||||
case SSEServerConfig:
|
||||
mcpClient, err = client.NewSSEMCPClient(cfg.Url,
|
||||
client.WithHeaders(parseHeaders(cfg.Headers)))
|
||||
case STDIOServerConfig:
|
||||
env := make([]string, 0, len(cfg.Env))
|
||||
for k, v := range cfg.Env {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
mcpClient, err = client.NewStdioMCPClient(cfg.Command, env, cfg.Args...)
|
||||
if stdioClient, ok := mcpClient.(*client.Client); ok {
|
||||
stderr, _ := client.GetStderr(stdioClient)
|
||||
startStdioLog(stderr, serverName)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported transport type: %s", config.GetType())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create client: %w", err)
|
||||
}
|
||||
|
||||
// initialize client
|
||||
_, err = mcpClient.Initialize(ctx, prepareClientInitRequest())
|
||||
if err != nil {
|
||||
mcpClient.Close()
|
||||
return errors.Wrapf(err, "initialize MCP client for %s failed", serverName)
|
||||
}
|
||||
|
||||
log.Info().Str("server", serverName).Msg("connected to MCP server")
|
||||
h.connections[serverName] = &Connection{
|
||||
Client: mcpClient,
|
||||
Config: config,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseServers closes all connected MCP servers
|
||||
func (h *MCPHost) CloseServers() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
log.Info().Msg("Shutting down MCP servers...")
|
||||
for name, conn := range h.connections {
|
||||
if err := conn.Client.Close(); err != nil {
|
||||
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
|
||||
} else {
|
||||
delete(h.connections, name)
|
||||
log.Info().Str("name", name).Msg("Server closed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClient returns the client for the specified server
|
||||
func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
conn, exists := h.connections[serverName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
}
|
||||
|
||||
return conn.Client, nil
|
||||
}
|
||||
|
||||
// GetTools returns all tools from all MCP servers
|
||||
func (h *MCPHost) GetTools(ctx context.Context) []MCPTools {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
var results []MCPTools
|
||||
|
||||
for serverName, conn := range h.connections {
|
||||
listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("server", serverName).Msg("failed to get tools")
|
||||
continue
|
||||
}
|
||||
|
||||
results = append(results, MCPTools{
|
||||
ServerName: serverName,
|
||||
Tools: listResults.Tools,
|
||||
Err: nil,
|
||||
})
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// GetTool returns a specific tool from a server
|
||||
func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mcp.Tool, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// Get all tools
|
||||
results := h.GetTools(ctx)
|
||||
|
||||
// Find the server's tools
|
||||
var serverTools MCPTools
|
||||
found := false
|
||||
for _, tools := range results {
|
||||
if tools.ServerName == serverName {
|
||||
serverTools = tools
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
}
|
||||
if serverTools.Err != nil {
|
||||
return nil, serverTools.Err
|
||||
}
|
||||
|
||||
// Find the specific tool
|
||||
for _, tool := range serverTools.Tools {
|
||||
if tool.Name == toolName {
|
||||
return &tool, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tool %s not found", toolName)
|
||||
}
|
||||
|
||||
// InvokeTool calls a tool with the given arguments
|
||||
func (h *MCPHost) InvokeTool(ctx context.Context,
|
||||
serverName, toolName string, arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
log.Info().Str("tool", toolName).Interface("args", arguments).
|
||||
Str("server", serverName).Msg("invoke tool")
|
||||
|
||||
conn, err := h.GetClient(serverName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err,
|
||||
"get mcp client for server %s failed", serverName)
|
||||
}
|
||||
|
||||
mcpTool, err := h.GetTool(ctx, serverName, toolName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err,
|
||||
"get mcp tool %s/%s failed", serverName, toolName)
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Meta *struct {
|
||||
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
}{
|
||||
Name: mcpTool.Name,
|
||||
Arguments: arguments,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := conn.CallTool(ctx, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err,
|
||||
"call tool %s/%s failed", serverName, toolName)
|
||||
}
|
||||
|
||||
if err := handleToolError(result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetEinoTool returns an eino tool for the given server and tool name
|
||||
func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string) (tool.BaseTool, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
conn, ok := h.connections[serverName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server not found: %s", serverName)
|
||||
}
|
||||
|
||||
// get tools from MCP server and convert to eino tools
|
||||
tools, err := mcpp.GetTools(ctx, &mcpp.Config{
|
||||
Cli: conn.Client,
|
||||
ToolNameList: []string{toolName},
|
||||
})
|
||||
if err != nil || len(tools) == 0 {
|
||||
log.Error().Err(err).
|
||||
Str("server", serverName).Str("tool", toolName).
|
||||
Msg("get MCP tool failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tools[0], nil
|
||||
}
|
||||
|
||||
// GetEinoToolInfos convert MCP tools to eino tool infos
|
||||
func (h *MCPHost) GetEinoToolInfos(ctx context.Context) ([]*schema.ToolInfo, error) {
|
||||
results := h.GetTools(ctx)
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("no MCP servers loaded")
|
||||
}
|
||||
|
||||
var tools []*schema.ToolInfo
|
||||
for _, serverTools := range results {
|
||||
if serverTools.Err != nil {
|
||||
log.Error().Err(serverTools.Err).Str("server", serverTools.ServerName).Msg("failed to get tools")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tool := range serverTools.Tools {
|
||||
einoTool, err := h.GetEinoTool(ctx, serverTools.ServerName, tool.Name)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("server", serverTools.ServerName).Str("tool", tool.Name).Msg("failed to get eino tool")
|
||||
continue
|
||||
}
|
||||
einoToolInfo, err := einoTool.Info(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("server", serverTools.ServerName).Str("tool", tool.Name).Msg("failed to get eino tool info")
|
||||
continue
|
||||
}
|
||||
einoToolInfo.Name = fmt.Sprintf("%s__%s", serverTools.ServerName, tool.Name)
|
||||
tools = append(tools, einoToolInfo)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Int("count", len(tools)).Msg("eino tool infos loaded")
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
// parseHeaders parses header strings into a map
|
||||
func parseHeaders(headerList []string) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range headerList {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// startStdioLog starts a goroutine to print stdio logs
|
||||
func startStdioLog(stderr io.Reader, serverName string) {
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// prepareClientInitRequest creates a standard initialization request
|
||||
func prepareClientInitRequest() mcp.InitializeRequest {
|
||||
return mcp.InitializeRequest{
|
||||
Params: struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities mcp.ClientCapabilities `json:"capabilities"`
|
||||
ClientInfo mcp.Implementation `json:"clientInfo"`
|
||||
}{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
Capabilities: mcp.ClientCapabilities{},
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: "hrp-mcphost",
|
||||
Version: version.GetVersionInfo(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleToolError handles tool execution errors
|
||||
func handleToolError(result *mcp.CallToolResult) error {
|
||||
if !result.IsError {
|
||||
return nil
|
||||
}
|
||||
if len(result.Content) > 0 {
|
||||
return fmt.Errorf("tool error: %v", result.Content[0])
|
||||
}
|
||||
return fmt.Errorf("tool error: unknown error")
|
||||
}
|
||||
@@ -1,233 +0,0 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewMCPHost(t *testing.T) {
|
||||
// Test with valid config file
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, host)
|
||||
assert.NotNil(t, host.config)
|
||||
assert.NotEmpty(t, host.config.MCPServers)
|
||||
|
||||
// Test with non-existent config file
|
||||
host, err = NewMCPHost("./testdata/non_existent.json")
|
||||
require.Error(t, err, "expected error when config file does not exist")
|
||||
assert.Nil(t, host)
|
||||
}
|
||||
|
||||
func TestInitServers(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify connections are established
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
assert.Contains(t, host.connections, "filesystem")
|
||||
assert.Contains(t, host.connections, "weather")
|
||||
}
|
||||
|
||||
func TestGetClient(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test getting existing client
|
||||
client, err := host.GetClient("weather")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// Test getting non-existent client
|
||||
client, err = host.GetClient("non_existent")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, client)
|
||||
}
|
||||
|
||||
func TestGetTools(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
|
||||
// Verify weather tools
|
||||
var weatherTools MCPTools
|
||||
for _, tool := range tools {
|
||||
if tool.ServerName == "weather" {
|
||||
weatherTools = tool
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, weatherTools.Err)
|
||||
assert.NotEmpty(t, weatherTools.Tools)
|
||||
|
||||
// Check if get_alerts tool exists
|
||||
found := false
|
||||
for _, tool := range weatherTools.Tools {
|
||||
if tool.Name == "get_alerts" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "get_alerts tool not found in weather tools")
|
||||
}
|
||||
|
||||
func TestGetTool(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test getting existing tool
|
||||
tool, err := host.GetTool(ctx, "weather", "get_alerts")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, tool)
|
||||
assert.Equal(t, "get_alerts", tool.Name)
|
||||
|
||||
// Test getting non-existent tool
|
||||
tool, err = host.GetTool(ctx, "weather", "non_existent")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tool)
|
||||
|
||||
// Test getting tool from non-existent server
|
||||
tool, err = host.GetTool(ctx, "non_existent", "get_alerts")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tool)
|
||||
}
|
||||
|
||||
func TestInvokeTool(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test invoking existing tool
|
||||
result, err := host.InvokeTool(ctx, "weather", "get_alerts",
|
||||
map[string]interface{}{"state": "CA"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
|
||||
// Test invoking non-existent tool
|
||||
result, err = host.InvokeTool(ctx, "weather", "non_existent",
|
||||
map[string]interface{}{"state": "CA"},
|
||||
)
|
||||
require.Error(t, err, "expected error when tool does not exist")
|
||||
assert.Nil(t, result)
|
||||
|
||||
// Test invoking tool with invalid arguments
|
||||
result, err = host.InvokeTool(ctx, "weather", "get_alerts",
|
||||
map[string]interface{}{"invalid_arg": "value"},
|
||||
)
|
||||
require.Error(t, err, "expected error when arguments are invalid")
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCallEinoTool(t *testing.T) {
|
||||
hub, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
einoTool, err := hub.GetEinoTool(ctx, "weather", "get_alerts")
|
||||
require.NoError(t, err)
|
||||
t.Logf("Tool: %v", einoTool)
|
||||
|
||||
tool := einoTool.(tool.InvokableTool)
|
||||
result, err := tool.InvokableRun(ctx, `{"state": "CA"}`)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Result: %v", result)
|
||||
}
|
||||
|
||||
func TestCloseServers(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify servers are connected
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
|
||||
// Close servers
|
||||
err = host.CloseServers()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify connections are closed
|
||||
assert.Empty(t, host.connections)
|
||||
}
|
||||
|
||||
func TestConcurrentOperations(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent tool invocations
|
||||
done := make(chan bool)
|
||||
timeout := time.After(30 * time.Second) // Increase timeout to 30 seconds
|
||||
|
||||
for i := 0; i < 3; i++ { // Reduce number of concurrent operations to 3
|
||||
go func() {
|
||||
result, err := host.InvokeTool(
|
||||
context.Background(), "weather", "get_alerts",
|
||||
map[string]interface{}{"state": "CA"},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 3; i++ { // Update loop count to match the number of goroutines
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for concurrent operations")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabledServer(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only enabled servers are connected
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
assert.Contains(t, host.connections, "filesystem")
|
||||
assert.Contains(t, host.connections, "weather")
|
||||
assert.NotContains(t, host.connections, "disabled_server")
|
||||
|
||||
// Test getting disabled server
|
||||
client, err := host.GetClient("disabled_server")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no connection found for server disabled_server")
|
||||
assert.Nil(t, client)
|
||||
|
||||
// Test getting tools from disabled server
|
||||
ctx := context.Background()
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
|
||||
// Verify enabled servers in tools list
|
||||
var foundFilesystem, foundWeather bool
|
||||
for _, serverTools := range tools {
|
||||
if serverTools.ServerName == "filesystem" {
|
||||
foundFilesystem = true
|
||||
} else if serverTools.ServerName == "weather" {
|
||||
foundWeather = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundFilesystem, "filesystem server not found in tools")
|
||||
assert.True(t, foundWeather, "weather server not found in tools")
|
||||
|
||||
// Test getting tool from disabled server
|
||||
tool, err := host.GetTool(ctx, "disabled_server", "some_tool")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no connection found for server disabled_server")
|
||||
assert.Nil(t, tool)
|
||||
}
|
||||
94
pkg/mcphost/testdata/demo_weather.py
vendored
94
pkg/mcphost/testdata/demo_weather.py
vendored
@@ -1,94 +0,0 @@
|
||||
from typing import Any
|
||||
import httpx
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize FastMCP server
|
||||
mcp = FastMCP("weather")
|
||||
|
||||
# Constants
|
||||
NWS_API_BASE = "https://api.weather.gov"
|
||||
USER_AGENT = "weather-app/1.0"
|
||||
|
||||
async def make_nws_request(url: str) -> dict[str, Any] | None:
|
||||
"""Make a request to the NWS API with proper error handling."""
|
||||
headers = {
|
||||
"User-Agent": USER_AGENT,
|
||||
"Accept": "application/geo+json"
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(url, headers=headers, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def format_alert(feature: dict) -> str:
|
||||
"""Format an alert feature into a readable string."""
|
||||
props = feature["properties"]
|
||||
return f"""
|
||||
Event: {props.get('event', 'Unknown')}
|
||||
Area: {props.get('areaDesc', 'Unknown')}
|
||||
Severity: {props.get('severity', 'Unknown')}
|
||||
Description: {props.get('description', 'No description available')}
|
||||
Instructions: {props.get('instruction', 'No specific instructions provided')}
|
||||
"""
|
||||
|
||||
@mcp.tool()
|
||||
async def get_alerts(state: str) -> str:
|
||||
"""Get weather alerts for a US state.
|
||||
|
||||
Args:
|
||||
state: Two-letter US state code (e.g. CA, NY)
|
||||
"""
|
||||
url = f"{NWS_API_BASE}/alerts/active/area/{state}"
|
||||
data = await make_nws_request(url)
|
||||
|
||||
if not data or "features" not in data:
|
||||
return "Unable to fetch alerts or no alerts found."
|
||||
|
||||
if not data["features"]:
|
||||
return "No active alerts for this state."
|
||||
|
||||
alerts = [format_alert(feature) for feature in data["features"]]
|
||||
return "\n---\n".join(alerts)
|
||||
|
||||
@mcp.tool()
|
||||
async def get_forecast(latitude: float, longitude: float) -> str:
|
||||
"""Get weather forecast for a location.
|
||||
|
||||
Args:
|
||||
latitude: Latitude of the location
|
||||
longitude: Longitude of the location
|
||||
"""
|
||||
# First get the forecast grid endpoint
|
||||
points_url = f"{NWS_API_BASE}/points/{latitude},{longitude}"
|
||||
points_data = await make_nws_request(points_url)
|
||||
|
||||
if not points_data:
|
||||
return "Unable to fetch forecast data for this location."
|
||||
|
||||
# Get the forecast URL from the points response
|
||||
forecast_url = points_data["properties"]["forecast"]
|
||||
forecast_data = await make_nws_request(forecast_url)
|
||||
|
||||
if not forecast_data:
|
||||
return "Unable to fetch detailed forecast."
|
||||
|
||||
# Format the periods into a readable forecast
|
||||
periods = forecast_data["properties"]["periods"]
|
||||
forecasts = []
|
||||
for period in periods[:5]: # Only show next 5 periods
|
||||
forecast = f"""
|
||||
{period['name']}:
|
||||
Temperature: {period['temperature']}°{period['temperatureUnit']}
|
||||
Wind: {period['windSpeed']} {period['windDirection']}
|
||||
Forecast: {period['detailedForecast']}
|
||||
"""
|
||||
forecasts.append(forecast)
|
||||
|
||||
return "\n---\n".join(forecasts)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize and run the server
|
||||
mcp.run(transport='stdio')
|
||||
32
pkg/mcphost/testdata/test.mcp.json
vendored
32
pkg/mcphost/testdata/test.mcp.json
vendored
@@ -1,32 +0,0 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"./"
|
||||
]
|
||||
},
|
||||
"weather": {
|
||||
"args": [
|
||||
"--directory",
|
||||
"/Users/debugtalk/MyProjects/HttpRunner-dev/httprunner/pkg/mcphost/testdata",
|
||||
"run",
|
||||
"demo_weather.py"
|
||||
],
|
||||
"autoApprove": [
|
||||
"get_forecast"
|
||||
],
|
||||
"command": "uv",
|
||||
"env": {
|
||||
"ABC": "123"
|
||||
}
|
||||
},
|
||||
"disabled_server": {
|
||||
"command": "echo",
|
||||
"args": ["disabled"],
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user