refactor: move mcphost package to top level

This commit is contained in:
lilong.129
2025-05-17 11:55:26 +08:00
parent 5d8c22f729
commit e94dacb5b2
18 changed files with 13 additions and 13 deletions

View File

@@ -1,5 +0,0 @@
# mcphost
This package is a fork of [mark3labs/mcphost], and it helps HttpRunner to interact with external tools through the Model Context Protocol (MCP).
[mark3labs/mcphost]: https://github.com/mark3labs/mcphost

View File

@@ -1,416 +0,0 @@
package mcphost
import (
"context"
"fmt"
"os"
"strings"
"github.com/bytedance/sonic"
"github.com/charmbracelet/glamour"
"github.com/charmbracelet/glamour/styles"
"github.com/charmbracelet/huh"
"github.com/charmbracelet/huh/spinner"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/lipgloss/list"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"golang.org/x/term"
)
// NewChat creates a new chat session
func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat, error) {
// Get model config from environment variables
modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeGPT)
if err != nil {
return nil, err
}
model, err := openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
if err != nil {
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
}
// Create markdown renderer
renderer, err := glamour.NewTermRenderer(
glamour.WithStandardStyle(styles.TokyoNightStyle),
glamour.WithWordWrap(getTerminalWidth()),
)
if err != nil {
return nil, errors.Wrap(err, "failed to create markdown renderer")
}
// Load system prompt from file if provided
systemPrompt := "chat to interact with MCP tools"
if systemPromptFile != "" {
customPrompt, err := loadSystemPrompt(systemPromptFile)
if err != nil {
return nil, errors.Wrap(err, "failed to load system prompt")
}
if customPrompt != "" {
systemPrompt = customPrompt
}
}
// convert MCP tools to eino tool infos
einoTools, err := h.GetEinoToolInfos(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get eino tool infos")
}
toolCallingModel, err := model.WithTools(einoTools)
if err != nil {
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
}
return &Chat{
model: toolCallingModel,
systemPrompt: systemPrompt,
history: ai.ConversationHistory{},
renderer: renderer,
host: h,
tools: einoTools,
}, nil
}
// Chat represents a chat session with LLM
type Chat struct {
model model.ToolCallingChatModel
systemPrompt string
history ai.ConversationHistory
renderer *glamour.TermRenderer
host *MCPHost
tools []*schema.ToolInfo
}
// Start starts the chat session
func (c *Chat) Start() error {
// Add system message
c.history = ai.ConversationHistory{
{
Role: schema.System,
Content: c.systemPrompt,
},
}
c.showWelcome()
for {
var input string
err := huh.NewForm(huh.NewGroup(huh.NewText().
Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)").
Value(&input).
CharLimit(5000)),
).WithWidth(getTerminalWidth()).
WithTheme(huh.ThemeCharm()).
Run()
if err != nil {
// Check if it's a user abort (Ctrl+C)
if errors.Is(err, huh.ErrUserAborted) {
fmt.Println("\nGoodbye!")
return nil // Exit cleanly
}
return err // Return other errors normally
}
if input == "" {
continue
}
// Handle commands
if strings.HasPrefix(input, "/") {
if err := c.handleCommand(input); err != nil {
log.Error().Err(err).Msg("failed to handle command")
}
continue
}
// run prompt with MCP tools
if err := c.runPrompt(input); err != nil {
log.Error().Err(err).Msg("run prompt error")
}
}
}
// runPrompt run prompt with MCP tools
func (c *Chat) runPrompt(prompt string) error {
fmt.Printf("\n%s\n", promptStyle.Render("You: "+prompt))
// Create user message
userMsg := &schema.Message{
Role: schema.User,
Content: prompt,
}
c.history = append(c.history, userMsg)
// Call LLM model to get response
ctx := context.Background()
var message *schema.Message
var modelErr error
_ = spinner.New().Title("Thinking...").Action(func() {
message, modelErr = c.model.Generate(ctx, c.history)
}).Run()
if modelErr != nil {
return modelErr
}
// Log usage statistics
if usage := message.ResponseMeta.Usage; usage != nil {
log.Debug().Int("input_tokens", usage.PromptTokens).
Int("output_tokens", usage.CompletionTokens).
Int("total_tokens", usage.TotalTokens).Msg("Usage statistics")
}
// Handle tool calls
toolCalls := message.ToolCalls
if len(toolCalls) > 0 {
return c.handleToolCalls(ctx, toolCalls)
}
// Add assistant's response to history
toolMsg := &schema.Message{
Role: schema.Assistant,
Content: message.Content,
}
c.history = append(c.history, toolMsg)
c.renderContent("Assistant", message.Content)
return nil
}
func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall) error {
for _, toolCall := range toolCalls {
serverToolName := toolCall.Function.Name
toolArgs := toolCall.Function.Arguments
log.Debug().Str("name", serverToolName).Str("args", toolArgs).Msg("handle tool call")
// Parse tool name
parts := strings.SplitN(serverToolName, "__", 2)
if len(parts) != 2 {
log.Error().Str("name", serverToolName).Msg("invalid tool name")
continue
}
serverName, toolName := parts[0], parts[1]
// Unmarshal tool arguments from JSON string
var argsMap map[string]interface{}
if err := sonic.UnmarshalString(toolArgs, &argsMap); err != nil {
log.Error().Err(err).Str("args", toolArgs).Msg("failed to unmarshal tool arguments")
continue
}
// Invoke tool
result, err := c.host.InvokeTool(ctx, serverName, toolName, argsMap)
if err != nil {
log.Error().Err(err).Msg("invoke tool failed")
continue
}
// Format tool result
resultStr := ""
if result != nil && len(result.Content) > 0 {
for _, item := range result.Content {
resultStr += fmt.Sprintf("%v\n", item)
}
} else {
resultStr = fmt.Sprintf("%+v", result)
}
c.renderContent("Tool result", resultStr)
// Add tool result to history
toolMsg := &schema.Message{
Role: schema.Tool,
Content: resultStr,
ToolCallID: toolCall.ID,
}
c.history = append(c.history, toolMsg)
}
return nil
}
// handleCommand handles commands
func (c *Chat) handleCommand(cmd string) error {
switch cmd {
case "/help":
c.showWelcome()
case "/tools":
c.showTools()
case "/history":
c.showHistory()
case "/clear":
c.clearHistory()
case "/quit":
fmt.Println("Goodbye!")
os.Exit(0)
default:
fmt.Printf("Unknown command: %s\n", cmd)
}
return nil
}
// showWelcome show welcome and help information
func (c *Chat) showWelcome() {
markdown := fmt.Sprintf(`# Welcome to HttpRunner MCPHost Chat!
## Available Commands
The following commands are available:
- **/help**: Show this help message
- **/tools**: List all available tools
- **/history**: Display conversation history
- **/clear**: Clear conversation history
- **/quit**: Exit the chat session
You can also press Ctrl+C at any time to quit.
## Configurations
- **system-prompt**: %s
- **mcp-config**: %s
`, c.systemPrompt, c.host.config.ConfigPath)
c.renderContent("", markdown)
}
func (c *Chat) showHistory() {
if len(c.history) <= 1 { // Only system message
fmt.Println("No conversation history yet.")
return
}
fmt.Println("\nConversation History:")
for _, msg := range c.history {
if msg.Role == schema.System {
continue
}
role := "You"
if msg.Role == schema.Assistant {
role = "Assistant"
}
c.renderContent(role, msg.Content)
}
}
func (c *Chat) clearHistory() {
// Keep only the system message
systemMsg := c.history[0]
c.history = ai.ConversationHistory{systemMsg}
fmt.Println("Conversation history cleared.")
}
func (c *Chat) showTools() {
if c.host == nil {
fmt.Println("No MCP host loaded.")
return
}
ctx := context.Background()
results := c.host.GetTools(ctx)
if len(results) == 0 {
fmt.Println("No MCP servers loaded.")
return
}
width := getTerminalWidth()
contentWidth := width - 12
l := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoPurple).MarginRight(1))
for _, serverTools := range results {
serverList := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoCyan).MarginRight(1))
if serverTools.Err != nil {
serverList.Item(contentStyle.Render(fmt.Sprintf("Error: %v", serverTools.Err)))
} else if len(serverTools.Tools) == 0 {
serverList.Item(contentStyle.Render("No tools available."))
} else {
for _, tool := range serverTools.Tools {
descStyle := lipgloss.NewStyle().Foreground(tokyoFg).Width(contentWidth).Align(lipgloss.Left)
toolDesc := list.New().EnumeratorStyle(
lipgloss.NewStyle().Foreground(tokyoGreen).MarginRight(1),
).Item(descStyle.Render(tool.Description))
serverList.Item(toolNameStyle.Render(tool.Name)).Item(toolDesc)
}
}
l.Item(serverTools.ServerName).Item(serverList)
}
containerStyle := lipgloss.NewStyle().Margin(2).Width(width)
fmt.Print("\n" + containerStyle.Render(l.String()) + "\n")
}
// Render and display content
func (c *Chat) renderContent(title, content string) {
output, err := c.renderer.Render(content)
if err != nil {
log.Error().Err(err).Msg("render content failed")
output = content
}
if title != "" {
title = title + ": "
}
fmt.Printf("\n%s", responseStyle.Render(title+output))
}
// loadSystemPrompt loads the system prompt from a JSON file
func loadSystemPrompt(filePath string) (string, error) {
// Check if file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
}
data, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("error reading prompt file: %v", err)
}
// Read file content directly as prompt
return string(data), nil
}
func getTerminalWidth() int {
width, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
return 80 // Fallback width
}
return width - 20
}
var (
// Tokyo Night theme colors
tokyoPurple = lipgloss.Color("99") // #9d7cd8
tokyoCyan = lipgloss.Color("73") // #7dcfff
tokyoBlue = lipgloss.Color("111") // #7aa2f7
tokyoGreen = lipgloss.Color("120") // #73daca
tokyoRed = lipgloss.Color("203") // #f7768e
tokyoOrange = lipgloss.Color("215") // #ff9e64
tokyoFg = lipgloss.Color("189") // #c0caf5
tokyoGray = lipgloss.Color("237") // #3b4261
tokyoBg = lipgloss.Color("234") // #1a1b26
promptStyle = lipgloss.NewStyle().
Foreground(tokyoBlue).
PaddingLeft(2)
responseStyle = lipgloss.NewStyle().
Foreground(tokyoFg).
PaddingLeft(2)
errorStyle = lipgloss.NewStyle().
Foreground(tokyoRed).
Bold(true)
toolNameStyle = lipgloss.NewStyle().
Foreground(tokyoCyan).
Bold(true)
descriptionStyle = lipgloss.NewStyle().
Foreground(tokyoFg).
PaddingBottom(1)
contentStyle = lipgloss.NewStyle().
Background(tokyoBg).
PaddingLeft(4).
PaddingRight(4)
)

View File

@@ -1,50 +0,0 @@
package mcphost
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewChat(t *testing.T) {
systemPromptFile := "test_system_prompt.txt"
_ = os.WriteFile(systemPromptFile, []byte("You are a helpful assistant."), 0o644)
defer os.Remove(systemPromptFile)
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), systemPromptFile)
assert.NoError(t, err)
assert.NotNil(t, chat)
assert.NotEmpty(t, chat.systemPrompt)
assert.NotNil(t, chat.tools)
}
func TestRunPromptWithNoToolCall(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), "")
assert.NoError(t, err)
err = chat.runPrompt("hi")
assert.NoError(t, err)
assert.True(t, len(chat.history) > 1)
}
func TestRunPromptWithToolCall(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), "")
assert.NoError(t, err)
assert.True(t, len(chat.tools) > 0)
err = chat.runPrompt("what is the weather in CA")
assert.NoError(t, err)
assert.True(t, len(chat.history) > 1)
}

View File

@@ -1,127 +0,0 @@
package mcphost
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
)
const (
transportStdio = "stdio"
transportSSE = "sse"
)
// MCPConfig represents the configuration for MCP servers
type MCPConfig struct {
ConfigPath string `json:"-"`
MCPServers map[string]ServerConfigWrapper `json:"mcpServers"`
}
// ServerConfig is an interface for different types of server configurations
type ServerConfig interface {
GetType() string
IsDisabled() bool
}
// STDIOServerConfig represents configuration for a STDIO-based server
type STDIOServerConfig struct {
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env,omitempty"`
Disabled bool `json:"disabled,omitempty"`
}
func (s STDIOServerConfig) GetType() string {
return transportStdio
}
func (s STDIOServerConfig) IsDisabled() bool {
return s.Disabled
}
// SSEServerConfig represents configuration for an SSE-based server
type SSEServerConfig struct {
Url string `json:"url"`
Headers []string `json:"headers,omitempty"`
Disabled bool `json:"disabled,omitempty"`
}
func (s SSEServerConfig) GetType() string {
return transportSSE
}
func (s SSEServerConfig) IsDisabled() bool {
return s.Disabled
}
// ServerConfigWrapper is a wrapper for different types of server configurations
type ServerConfigWrapper struct {
Config ServerConfig
}
func (w *ServerConfigWrapper) UnmarshalJSON(data []byte) error {
var typeField struct {
Url string `json:"url"`
}
if err := json.Unmarshal(data, &typeField); err != nil {
return err
}
if typeField.Url != "" {
// If the URL field is present, treat it as an SSE server
var sse SSEServerConfig
if err := json.Unmarshal(data, &sse); err != nil {
return err
}
w.Config = sse
} else {
// Otherwise, treat it as a STDIOServerConfig
var stdio STDIOServerConfig
if err := json.Unmarshal(data, &stdio); err != nil {
return err
}
w.Config = stdio
}
return nil
}
func (w ServerConfigWrapper) MarshalJSON() ([]byte, error) {
return json.Marshal(w.Config)
}
// LoadMCPConfig loads the MCP configuration from the specified path or default location
func LoadMCPConfig(configPath string) (*MCPConfig, error) {
if configPath == "" {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("error getting home directory: %w", err)
}
configPath = filepath.Join(homeDir, ".mcp.json")
}
configPath = os.ExpandEnv(configPath)
// Check if config file exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return nil, fmt.Errorf("config file does not exist: %s", configPath)
}
// Read existing config
configData, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf(
"error reading config file %s: %w",
configPath,
err,
)
}
var config MCPConfig
if err := json.Unmarshal(configData, &config); err != nil {
return nil, fmt.Errorf("error parsing config file: %w", err)
}
config.ConfigPath = configPath
return &config, nil
}

View File

@@ -1,30 +0,0 @@
package mcphost
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestLoadSettings(t *testing.T) {
// Load settings from test.mcp.json
settings, err := LoadMCPConfig("testdata/test.mcp.json")
if err != nil {
t.Fatalf("Failed to load settings: %v", err)
}
// Verify settings are loaded correctly
assert.NotNil(t, settings)
assert.Contains(t, settings.MCPServers, "filesystem")
assert.Contains(t, settings.MCPServers, "weather")
// Verify specific server configurations
filesystemConfig := settings.MCPServers["filesystem"].Config.(STDIOServerConfig)
assert.Equal(t, "npx", filesystemConfig.Command)
assert.Equal(t, []string{"-y", "@modelcontextprotocol/server-filesystem", "./"}, filesystemConfig.Args)
weatherConfig := settings.MCPServers["weather"].Config.(STDIOServerConfig)
assert.Equal(t, "uv", weatherConfig.Command)
assert.Equal(t, []string{"--directory", "/Users/debugtalk/MyProjects/HttpRunner-dev/httprunner/pkg/mcphost/testdata", "run", "demo_weather.py"}, weatherConfig.Args)
assert.Equal(t, map[string]string{"ABC": "123"}, weatherConfig.Env)
}

View File

@@ -1,185 +0,0 @@
package mcphost
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/rs/zerolog/log"
)
// MCPToolRecord represents a single tool record in the database
// Each record contains detailed information about a tool and its server
type MCPToolRecord struct {
ToolID string `json:"tool_id"` // Unique identifier for the tool record
ServerName string `json:"mcp_server"` // Name of the MCP server
ToolName string `json:"tool_name"` // Name of the tool
Description string `json:"description"` // Tool description
Parameters string `json:"parameters"` // Tool input parameters in JSON format
Returns string `json:"returns"` // Tool return value format in JSON format
CreatedAt time.Time `json:"created_at"` // Record creation time
LastUpdatedAt time.Time `json:"last_updated_at"` // Record last update time
}
// DocStringInfo contains the parsed information from a Python docstring
type DocStringInfo struct {
Description string
Parameters map[string]string
Returns map[string]string
}
// extractDocStringInfo extracts information from a Python docstring
// Example input:
// """Get weather alerts for a US state.
//
// Args:
// state: Two-letter US state code (e.g. CA, NY)
//
// Returns:
// alerts: List of active weather alerts for the specified state
// error: Error message if the request fails
// """
func extractDocStringInfo(docstring string) DocStringInfo {
info := DocStringInfo{
Parameters: make(map[string]string),
Returns: make(map[string]string),
}
// Find the Args and Returns sections
argsIndex := strings.Index(docstring, "Args:")
returnsIndex := strings.Index(docstring, "Returns:")
// Extract description (everything before Args)
if argsIndex != -1 {
info.Description = strings.TrimSpace(docstring[:argsIndex])
} else if returnsIndex != -1 {
info.Description = strings.TrimSpace(docstring[:returnsIndex])
} else {
info.Description = strings.TrimSpace(docstring)
return info
}
// Helper function to extract key-value pairs from a section
extractSection := func(content string) map[string]string {
result := make(map[string]string)
lines := strings.Split(content, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key != "" && value != "" {
result[key] = value
}
}
return result
}
// Extract Args section
if argsIndex != -1 {
endIndex := returnsIndex
if endIndex == -1 {
endIndex = len(docstring)
}
argsContent := docstring[argsIndex+len("Args:") : endIndex]
info.Parameters = extractSection(argsContent)
}
// Extract Returns section
if returnsIndex != -1 {
returnsContent := docstring[returnsIndex+len("Returns:"):]
info.Returns = extractSection(returnsContent)
}
return info
}
// ConvertToolsToRecords converts []MCPTools to a list of database records
func ConvertToolsToRecords(tools []MCPTools) []MCPToolRecord {
var records []MCPToolRecord
now := time.Now()
for _, mcpTools := range tools {
if mcpTools.Err != nil {
log.Error().Str("server", mcpTools.ServerName).Err(mcpTools.Err).Msg("skip tools conversion due to error")
continue
}
for _, tool := range mcpTools.Tools {
// Generate unique ID by combining server name and tool name
id := fmt.Sprintf("%s_%s", mcpTools.ServerName, tool.Name)
// Extract docstring information
info := extractDocStringInfo(tool.Description)
// Convert parameters and returns to JSON
paramsJSON, err := sonic.MarshalString(info.Parameters)
if err != nil {
log.Warn().Interface("params", info.Parameters).Err(err).Msg("failed to marshal parameters to JSON")
paramsJSON = "{}"
}
returnsJSON, err := sonic.MarshalString(info.Returns)
if err != nil {
log.Warn().Interface("returns", info.Returns).Err(err).Msg("failed to marshal returns to JSON")
returnsJSON = "{}"
}
record := MCPToolRecord{
ToolID: id,
ServerName: mcpTools.ServerName,
ToolName: tool.Name,
Description: info.Description,
Parameters: paramsJSON,
Returns: returnsJSON,
CreatedAt: now,
LastUpdatedAt: now,
}
records = append(records, record)
}
}
return records
}
// ExportToolsToJSON dumps MCP tools to JSON file
func (h *MCPHost) ExportToolsToJSON(ctx context.Context, dumpPath string) error {
// get all tools
tools := h.GetTools(ctx)
// convert to records
records := ConvertToolsToRecords(tools)
// convert to JSON
recordsJSON, err := sonic.MarshalIndent(records, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal records to JSON: %w", err)
}
// create output directory
outputDir := filepath.Dir(dumpPath)
if outputDir != "." {
if err := os.MkdirAll(outputDir, 0o755); err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}
}
// write to file
if err := os.WriteFile(dumpPath, []byte(recordsJSON), 0o644); err != nil {
return fmt.Errorf("failed to write records to file: %w", err)
}
log.Info().Str("path", dumpPath).Msg("Tools records exported successfully")
return nil
}

View File

@@ -1,237 +0,0 @@
package mcphost
import (
"context"
"encoding/json"
"os"
"testing"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConvertToolsToRecordsFromFile(t *testing.T) {
hub, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
// use ExportToolsToJSON to dump tools to JSON file
err = hub.ExportToolsToJSON(context.Background(), "./tools_records.json")
require.NoError(t, err)
// read the exported JSON file
data, err := os.ReadFile("./tools_records.json")
require.NoError(t, err)
// parse the exported JSON data
var records []MCPToolRecord
err = json.Unmarshal(data, &records)
require.NoError(t, err)
// verify the number of records
assert.NotEmpty(t, records, "Exported records should not be empty")
t.Logf("Tools records written to ./tools_records.json")
}
func TestExtractDocStringInfo(t *testing.T) {
tests := []struct {
name string
docstring string
want DocStringInfo
}{
{
name: "complete docstring with args and returns",
docstring: `Get weather alerts for a US state.
Args:
state: Two-letter US state code (e.g. CA, NY)
Returns:
alerts: List of active weather alerts for the specified state
error: Error message if the request fails
`,
want: DocStringInfo{
Description: "Get weather alerts for a US state.",
Parameters: map[string]string{
"state": "Two-letter US state code (e.g. CA, NY)",
},
Returns: map[string]string{
"alerts": "List of active weather alerts for the specified state",
"error": "Error message if the request fails",
},
},
},
{
name: "docstring with only args",
docstring: `Do screen swipe action.
Args:
direction: swipe direction (up, down)
`,
want: DocStringInfo{
Description: "Do screen swipe action.",
Parameters: map[string]string{
"direction": "swipe direction (up, down)",
},
Returns: map[string]string{},
},
},
{
name: "docstring with only description",
docstring: "Simple tool with no parameters.",
want: DocStringInfo{
Description: "Simple tool with no parameters.",
Parameters: map[string]string{},
Returns: map[string]string{},
},
},
{
name: "docstring with multiple parameters",
docstring: `Perform complex operation.
Args:
param1: first parameter description
param2: second parameter description
param3: third parameter description
Returns:
result: operation result
`,
want: DocStringInfo{
Description: "Perform complex operation.",
Parameters: map[string]string{
"param1": "first parameter description",
"param2": "second parameter description",
"param3": "third parameter description",
},
Returns: map[string]string{
"result": "operation result",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractDocStringInfo(tt.docstring)
assert.Equal(t, tt.want.Description, got.Description)
assert.Equal(t, tt.want.Parameters, got.Parameters)
assert.Equal(t, tt.want.Returns, got.Returns)
})
}
}
func TestConvertToolsToRecords(t *testing.T) {
tests := []struct {
name string
tools []MCPTools
want []MCPToolRecord
}{
{
name: "convert weather tool",
tools: []MCPTools{
{
ServerName: "weather",
Tools: []mcp.Tool{
{
Name: "get_alerts",
Description: `Get weather alerts for a US state.
Args:
state: Two-letter US state code (e.g. CA, NY)
Returns:
alerts: List of active weather alerts for the specified state
error: Error message if the request fails
`,
},
},
},
},
want: []MCPToolRecord{
{
ToolID: "weather_get_alerts",
ServerName: "weather",
ToolName: "get_alerts",
Description: "Get weather alerts for a US state.",
Parameters: `{"state":"Two-letter US state code (e.g. CA, NY)"}`,
Returns: `{"alerts":"List of active weather alerts for the specified state","error":"Error message if the request fails"}`,
},
},
},
{
name: "convert multiple tools",
tools: []MCPTools{
{
ServerName: "ui",
Tools: []mcp.Tool{
{
Name: "swipe",
Description: `Do screen swipe action.
Args:
direction: swipe direction (up, down)
`,
},
{
Name: "tap",
Description: "Tap on screen at specified position.",
},
},
},
},
want: []MCPToolRecord{
{
ToolID: "ui_swipe",
ServerName: "ui",
ToolName: "swipe",
Description: "Do screen swipe action.",
Parameters: `{"direction":"swipe direction (up, down)"}`,
Returns: "{}",
},
{
ToolID: "ui_tap",
ServerName: "ui",
ToolName: "tap",
Description: "Tap on screen at specified position.",
Parameters: "{}",
Returns: "{}",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ConvertToolsToRecords(tt.tools)
// Compare each record
require.Equal(t, len(tt.want), len(got))
for i := range tt.want {
assert.Equal(t, tt.want[i].ToolID, got[i].ToolID)
assert.Equal(t, tt.want[i].ServerName, got[i].ServerName)
assert.Equal(t, tt.want[i].ToolName, got[i].ToolName)
assert.Equal(t, tt.want[i].Description, got[i].Description)
// Compare JSON content (ignoring whitespace differences)
var wantParams, gotParams, wantReturns, gotReturns map[string]string
require.NoError(t, json.Unmarshal([]byte(tt.want[i].Parameters), &wantParams))
require.NoError(t, json.Unmarshal([]byte(got[i].Parameters), &gotParams))
require.NoError(t, json.Unmarshal([]byte(tt.want[i].Returns), &wantReturns))
require.NoError(t, json.Unmarshal([]byte(got[i].Returns), &gotReturns))
assert.Equal(t, wantParams, gotParams)
assert.Equal(t, wantReturns, gotReturns)
// Verify timestamps are recent (within last 5 seconds)
now := time.Now()
assert.True(t, now.Sub(got[i].CreatedAt) < 5*time.Second, "CreatedAt should be recent")
assert.True(t, now.Sub(got[i].LastUpdatedAt) < 5*time.Second, "LastUpdatedAt should be recent")
// CreatedAt and LastUpdatedAt should be the same
assert.Equal(t, got[i].CreatedAt, got[i].LastUpdatedAt)
}
})
}
}

View File

@@ -1,375 +0,0 @@
package mcphost
import (
"bufio"
"context"
"fmt"
"io"
"os"
"strings"
"sync"
mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/internal/version"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
// MCPHost manages MCP server connections and tools
type MCPHost struct {
mu sync.RWMutex
connections map[string]*Connection
config *MCPConfig
}
// Connection represents a connection to an MCP server
type Connection struct {
Client client.MCPClient
Config ServerConfig
}
// MCPTools represents tools from a single MCP server
type MCPTools struct {
ServerName string
Tools []mcp.Tool
Err error
}
// NewMCPHost creates a new MCPHost instance
func NewMCPHost(configPath string) (*MCPHost, error) {
config, err := LoadMCPConfig(configPath)
if err != nil {
return nil, err
}
host := &MCPHost{
connections: make(map[string]*Connection),
config: config,
}
// Initialize MCP servers
if err := host.InitServers(context.Background()); err != nil {
return nil, fmt.Errorf("failed to initialize MCP servers: %w", err)
}
return host, nil
}
// InitServers initializes all MCP servers
func (h *MCPHost) InitServers(ctx context.Context) error {
for name, server := range h.config.MCPServers {
if server.Config.IsDisabled() {
continue
}
if err := h.connectToServer(ctx, name, server.Config); err != nil {
return fmt.Errorf("failed to connect to server %s: %w", name, err)
}
}
return nil
}
// connectToServer establishes connection to a single MCP server
func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config ServerConfig) error {
h.mu.Lock()
defer h.mu.Unlock()
log.Debug().Str("server", serverName).Msg("connecting to MCP server")
// Close existing connection if any
if existing, exists := h.connections[serverName]; exists {
if err := existing.Client.Close(); err != nil {
return fmt.Errorf("failed to close existing connection: %w", err)
}
delete(h.connections, serverName)
}
var mcpClient client.MCPClient
var err error
// create client based on server type
switch cfg := config.(type) {
case SSEServerConfig:
mcpClient, err = client.NewSSEMCPClient(cfg.Url,
client.WithHeaders(parseHeaders(cfg.Headers)))
case STDIOServerConfig:
env := make([]string, 0, len(cfg.Env))
for k, v := range cfg.Env {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
mcpClient, err = client.NewStdioMCPClient(cfg.Command, env, cfg.Args...)
if stdioClient, ok := mcpClient.(*client.Client); ok {
stderr, _ := client.GetStderr(stdioClient)
startStdioLog(stderr, serverName)
}
default:
return fmt.Errorf("unsupported transport type: %s", config.GetType())
}
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
}
// initialize client
_, err = mcpClient.Initialize(ctx, prepareClientInitRequest())
if err != nil {
mcpClient.Close()
return errors.Wrapf(err, "initialize MCP client for %s failed", serverName)
}
log.Info().Str("server", serverName).Msg("connected to MCP server")
h.connections[serverName] = &Connection{
Client: mcpClient,
Config: config,
}
return nil
}
// CloseServers closes all connected MCP servers
func (h *MCPHost) CloseServers() error {
h.mu.Lock()
defer h.mu.Unlock()
log.Info().Msg("Shutting down MCP servers...")
for name, conn := range h.connections {
if err := conn.Client.Close(); err != nil {
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
} else {
delete(h.connections, name)
log.Info().Str("name", name).Msg("Server closed")
}
}
return nil
}
// GetClient returns the client for the specified server
func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
h.mu.RLock()
defer h.mu.RUnlock()
conn, exists := h.connections[serverName]
if !exists {
return nil, fmt.Errorf("no connection found for server %s", serverName)
}
return conn.Client, nil
}
// GetTools returns all tools from all MCP servers
func (h *MCPHost) GetTools(ctx context.Context) []MCPTools {
h.mu.RLock()
defer h.mu.RUnlock()
var results []MCPTools
for serverName, conn := range h.connections {
listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
log.Error().Err(err).Str("server", serverName).Msg("failed to get tools")
continue
}
results = append(results, MCPTools{
ServerName: serverName,
Tools: listResults.Tools,
Err: nil,
})
}
return results
}
// GetTool returns a specific tool from a server
func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mcp.Tool, error) {
h.mu.RLock()
defer h.mu.RUnlock()
// Get all tools
results := h.GetTools(ctx)
// Find the server's tools
var serverTools MCPTools
found := false
for _, tools := range results {
if tools.ServerName == serverName {
serverTools = tools
found = true
break
}
}
if !found {
return nil, fmt.Errorf("no connection found for server %s", serverName)
}
if serverTools.Err != nil {
return nil, serverTools.Err
}
// Find the specific tool
for _, tool := range serverTools.Tools {
if tool.Name == toolName {
return &tool, nil
}
}
return nil, fmt.Errorf("tool %s not found", toolName)
}
// InvokeTool calls a tool with the given arguments
func (h *MCPHost) InvokeTool(ctx context.Context,
serverName, toolName string, arguments map[string]any,
) (*mcp.CallToolResult, error) {
log.Info().Str("tool", toolName).Interface("args", arguments).
Str("server", serverName).Msg("invoke tool")
conn, err := h.GetClient(serverName)
if err != nil {
return nil, errors.Wrapf(err,
"get mcp client for server %s failed", serverName)
}
mcpTool, err := h.GetTool(ctx, serverName, toolName)
if err != nil {
return nil, errors.Wrapf(err,
"get mcp tool %s/%s failed", serverName, toolName)
}
req := mcp.CallToolRequest{
Params: struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
Meta *struct {
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
} `json:"_meta,omitempty"`
}{
Name: mcpTool.Name,
Arguments: arguments,
},
}
result, err := conn.CallTool(ctx, req)
if err != nil {
return nil, errors.Wrapf(err,
"call tool %s/%s failed", serverName, toolName)
}
if err := handleToolError(result); err != nil {
return nil, err
}
return result, nil
}
// GetEinoTool returns an eino tool for the given server and tool name
func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string) (tool.BaseTool, error) {
h.mu.RLock()
defer h.mu.RUnlock()
conn, ok := h.connections[serverName]
if !ok {
return nil, fmt.Errorf("server not found: %s", serverName)
}
// get tools from MCP server and convert to eino tools
tools, err := mcpp.GetTools(ctx, &mcpp.Config{
Cli: conn.Client,
ToolNameList: []string{toolName},
})
if err != nil || len(tools) == 0 {
log.Error().Err(err).
Str("server", serverName).Str("tool", toolName).
Msg("get MCP tool failed")
return nil, err
}
return tools[0], nil
}
// GetEinoToolInfos convert MCP tools to eino tool infos
func (h *MCPHost) GetEinoToolInfos(ctx context.Context) ([]*schema.ToolInfo, error) {
results := h.GetTools(ctx)
if len(results) == 0 {
return nil, fmt.Errorf("no MCP servers loaded")
}
var tools []*schema.ToolInfo
for _, serverTools := range results {
if serverTools.Err != nil {
log.Error().Err(serverTools.Err).Str("server", serverTools.ServerName).Msg("failed to get tools")
continue
}
for _, tool := range serverTools.Tools {
einoTool, err := h.GetEinoTool(ctx, serverTools.ServerName, tool.Name)
if err != nil {
log.Error().Err(err).Str("server", serverTools.ServerName).Str("tool", tool.Name).Msg("failed to get eino tool")
continue
}
einoToolInfo, err := einoTool.Info(ctx)
if err != nil {
log.Error().Err(err).Str("server", serverTools.ServerName).Str("tool", tool.Name).Msg("failed to get eino tool info")
continue
}
einoToolInfo.Name = fmt.Sprintf("%s__%s", serverTools.ServerName, tool.Name)
tools = append(tools, einoToolInfo)
}
}
log.Info().Int("count", len(tools)).Msg("eino tool infos loaded")
return tools, nil
}
// parseHeaders parses header strings into a map
func parseHeaders(headerList []string) map[string]string {
headers := make(map[string]string)
for _, header := range headerList {
parts := strings.SplitN(header, ":", 2)
if len(parts) == 2 {
headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
}
return headers
}
// startStdioLog starts a goroutine to print stdio logs
func startStdioLog(stderr io.Reader, serverName string) {
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
}
}()
}
// prepareClientInitRequest creates a standard initialization request
func prepareClientInitRequest() mcp.InitializeRequest {
return mcp.InitializeRequest{
Params: struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities mcp.ClientCapabilities `json:"capabilities"`
ClientInfo mcp.Implementation `json:"clientInfo"`
}{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
Capabilities: mcp.ClientCapabilities{},
ClientInfo: mcp.Implementation{
Name: "hrp-mcphost",
Version: version.GetVersionInfo(),
},
},
}
}
// handleToolError handles tool execution errors
func handleToolError(result *mcp.CallToolResult) error {
if !result.IsError {
return nil
}
if len(result.Content) > 0 {
return fmt.Errorf("tool error: %v", result.Content[0])
}
return fmt.Errorf("tool error: unknown error")
}

View File

@@ -1,233 +0,0 @@
package mcphost
import (
"context"
"testing"
"time"
"github.com/cloudwego/eino/components/tool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewMCPHost(t *testing.T) {
// Test with valid config file
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
assert.NotNil(t, host)
assert.NotNil(t, host.config)
assert.NotEmpty(t, host.config.MCPServers)
// Test with non-existent config file
host, err = NewMCPHost("./testdata/non_existent.json")
require.Error(t, err, "expected error when config file does not exist")
assert.Nil(t, host)
}
func TestInitServers(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
// Verify connections are established
assert.Equal(t, 2, len(host.connections))
assert.Contains(t, host.connections, "filesystem")
assert.Contains(t, host.connections, "weather")
}
func TestGetClient(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
// Test getting existing client
client, err := host.GetClient("weather")
require.NoError(t, err)
assert.NotNil(t, client)
// Test getting non-existent client
client, err = host.GetClient("non_existent")
assert.Error(t, err)
assert.Nil(t, client)
}
func TestGetTools(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
ctx := context.Background()
tools := host.GetTools(ctx)
assert.Equal(t, 2, len(tools))
// Verify weather tools
var weatherTools MCPTools
for _, tool := range tools {
if tool.ServerName == "weather" {
weatherTools = tool
break
}
}
assert.NoError(t, weatherTools.Err)
assert.NotEmpty(t, weatherTools.Tools)
// Check if get_alerts tool exists
found := false
for _, tool := range weatherTools.Tools {
if tool.Name == "get_alerts" {
found = true
break
}
}
assert.True(t, found, "get_alerts tool not found in weather tools")
}
func TestGetTool(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
ctx := context.Background()
// Test getting existing tool
tool, err := host.GetTool(ctx, "weather", "get_alerts")
require.NoError(t, err)
assert.NotNil(t, tool)
assert.Equal(t, "get_alerts", tool.Name)
// Test getting non-existent tool
tool, err = host.GetTool(ctx, "weather", "non_existent")
assert.Error(t, err)
assert.Nil(t, tool)
// Test getting tool from non-existent server
tool, err = host.GetTool(ctx, "non_existent", "get_alerts")
assert.Error(t, err)
assert.Nil(t, tool)
}
func TestInvokeTool(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
ctx := context.Background()
// Test invoking existing tool
result, err := host.InvokeTool(ctx, "weather", "get_alerts",
map[string]interface{}{"state": "CA"},
)
require.NoError(t, err)
assert.NotNil(t, result)
// Test invoking non-existent tool
result, err = host.InvokeTool(ctx, "weather", "non_existent",
map[string]interface{}{"state": "CA"},
)
require.Error(t, err, "expected error when tool does not exist")
assert.Nil(t, result)
// Test invoking tool with invalid arguments
result, err = host.InvokeTool(ctx, "weather", "get_alerts",
map[string]interface{}{"invalid_arg": "value"},
)
require.Error(t, err, "expected error when arguments are invalid")
assert.Nil(t, result)
}
func TestCallEinoTool(t *testing.T) {
hub, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
ctx := context.Background()
einoTool, err := hub.GetEinoTool(ctx, "weather", "get_alerts")
require.NoError(t, err)
t.Logf("Tool: %v", einoTool)
tool := einoTool.(tool.InvokableTool)
result, err := tool.InvokableRun(ctx, `{"state": "CA"}`)
require.NoError(t, err)
t.Logf("Result: %v", result)
}
func TestCloseServers(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
// Verify servers are connected
assert.Equal(t, 2, len(host.connections))
// Close servers
err = host.CloseServers()
require.NoError(t, err)
// Verify connections are closed
assert.Empty(t, host.connections)
}
func TestConcurrentOperations(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
// Test concurrent tool invocations
done := make(chan bool)
timeout := time.After(30 * time.Second) // Increase timeout to 30 seconds
for i := 0; i < 3; i++ { // Reduce number of concurrent operations to 3
go func() {
result, err := host.InvokeTool(
context.Background(), "weather", "get_alerts",
map[string]interface{}{"state": "CA"},
)
assert.NoError(t, err)
assert.NotNil(t, result)
done <- true
}()
}
// Wait for all goroutines to complete
for i := 0; i < 3; i++ { // Update loop count to match the number of goroutines
select {
case <-done:
// Success
case <-timeout:
t.Fatal("Timeout waiting for concurrent operations")
}
}
}
func TestDisabledServer(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
// Verify only enabled servers are connected
assert.Equal(t, 2, len(host.connections))
assert.Contains(t, host.connections, "filesystem")
assert.Contains(t, host.connections, "weather")
assert.NotContains(t, host.connections, "disabled_server")
// Test getting disabled server
client, err := host.GetClient("disabled_server")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no connection found for server disabled_server")
assert.Nil(t, client)
// Test getting tools from disabled server
ctx := context.Background()
tools := host.GetTools(ctx)
assert.Equal(t, 2, len(tools))
// Verify enabled servers in tools list
var foundFilesystem, foundWeather bool
for _, serverTools := range tools {
if serverTools.ServerName == "filesystem" {
foundFilesystem = true
} else if serverTools.ServerName == "weather" {
foundWeather = true
}
}
assert.True(t, foundFilesystem, "filesystem server not found in tools")
assert.True(t, foundWeather, "weather server not found in tools")
// Test getting tool from disabled server
tool, err := host.GetTool(ctx, "disabled_server", "some_tool")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no connection found for server disabled_server")
assert.Nil(t, tool)
}

View File

@@ -1,94 +0,0 @@
from typing import Any
import httpx
from mcp.server.fastmcp import FastMCP
# Initialize FastMCP server
mcp = FastMCP("weather")
# Constants
NWS_API_BASE = "https://api.weather.gov"
USER_AGENT = "weather-app/1.0"
async def make_nws_request(url: str) -> dict[str, Any] | None:
"""Make a request to the NWS API with proper error handling."""
headers = {
"User-Agent": USER_AGENT,
"Accept": "application/geo+json"
}
async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers, timeout=30.0)
response.raise_for_status()
return response.json()
except Exception:
return None
def format_alert(feature: dict) -> str:
"""Format an alert feature into a readable string."""
props = feature["properties"]
return f"""
Event: {props.get('event', 'Unknown')}
Area: {props.get('areaDesc', 'Unknown')}
Severity: {props.get('severity', 'Unknown')}
Description: {props.get('description', 'No description available')}
Instructions: {props.get('instruction', 'No specific instructions provided')}
"""
@mcp.tool()
async def get_alerts(state: str) -> str:
"""Get weather alerts for a US state.
Args:
state: Two-letter US state code (e.g. CA, NY)
"""
url = f"{NWS_API_BASE}/alerts/active/area/{state}"
data = await make_nws_request(url)
if not data or "features" not in data:
return "Unable to fetch alerts or no alerts found."
if not data["features"]:
return "No active alerts for this state."
alerts = [format_alert(feature) for feature in data["features"]]
return "\n---\n".join(alerts)
@mcp.tool()
async def get_forecast(latitude: float, longitude: float) -> str:
"""Get weather forecast for a location.
Args:
latitude: Latitude of the location
longitude: Longitude of the location
"""
# First get the forecast grid endpoint
points_url = f"{NWS_API_BASE}/points/{latitude},{longitude}"
points_data = await make_nws_request(points_url)
if not points_data:
return "Unable to fetch forecast data for this location."
# Get the forecast URL from the points response
forecast_url = points_data["properties"]["forecast"]
forecast_data = await make_nws_request(forecast_url)
if not forecast_data:
return "Unable to fetch detailed forecast."
# Format the periods into a readable forecast
periods = forecast_data["properties"]["periods"]
forecasts = []
for period in periods[:5]: # Only show next 5 periods
forecast = f"""
{period['name']}:
Temperature: {period['temperature']}°{period['temperatureUnit']}
Wind: {period['windSpeed']} {period['windDirection']}
Forecast: {period['detailedForecast']}
"""
forecasts.append(forecast)
return "\n---\n".join(forecasts)
if __name__ == "__main__":
# Initialize and run the server
mcp.run(transport='stdio')

View File

@@ -1,32 +0,0 @@
{
"mcpServers": {
"filesystem": {
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-filesystem",
"./"
]
},
"weather": {
"args": [
"--directory",
"/Users/debugtalk/MyProjects/HttpRunner-dev/httprunner/pkg/mcphost/testdata",
"run",
"demo_weather.py"
],
"autoApprove": [
"get_forecast"
],
"command": "uv",
"env": {
"ABC": "123"
}
},
"disabled_server": {
"command": "echo",
"args": ["disabled"],
"disabled": true
}
}
}