mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-10 17:43:00 +08:00
feat: chat with mcp tools
This commit is contained in:
@@ -11,39 +11,40 @@ import (
|
||||
// CmdMCPHost represents the mcphost command
|
||||
var CmdMCPHost = &cobra.Command{
|
||||
Use: "mcphost",
|
||||
Short: "Export MCP server tools to JSON description",
|
||||
Long: `Export all tools from MCP servers to JSON description.
|
||||
The tools will be exported with their descriptions, parameters, and return values.`,
|
||||
Short: "Start a chat session to interact with MCP tools",
|
||||
Long: `mcphost is a command-line tool that allows you to interact with MCP tools.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Create MCP host
|
||||
host, err := mcphost.NewMCPHost(mcpConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create MCP host: %w", err)
|
||||
}
|
||||
|
||||
// Initialize servers
|
||||
ctx := context.Background()
|
||||
if err := host.InitServers(ctx); err != nil {
|
||||
return fmt.Errorf("failed to initialize MCP servers: %w", err)
|
||||
}
|
||||
defer host.CloseServers()
|
||||
|
||||
// Export tools to JSON
|
||||
// If dump flag is set, dump MCP server tools to JSON file
|
||||
if dumpPath != "" {
|
||||
if err := host.ExportToolsToJSON(ctx, dumpPath); err != nil {
|
||||
return err
|
||||
}
|
||||
return host.ExportToolsToJSON(context.Background(), dumpPath)
|
||||
}
|
||||
return nil
|
||||
|
||||
// Create chat session
|
||||
chat, err := host.NewChat(context.Background(), systemPromptFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create chat session: %w", err)
|
||||
}
|
||||
|
||||
// Start chat
|
||||
return chat.Start()
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
mcpConfigPath string
|
||||
dumpPath string
|
||||
mcpConfigPath string
|
||||
dumpPath string
|
||||
systemPromptFile string
|
||||
)
|
||||
|
||||
func init() {
|
||||
CmdMCPHost.Flags().StringVarP(&mcpConfigPath, "mcp-config", "c", "$HOME/.hrp/mcp.json", "path to the MCP config file")
|
||||
CmdMCPHost.Flags().StringVar(&dumpPath, "dump", "", "path to save the exported tools JSON file")
|
||||
CmdMCPHost.Flags().StringVar(&systemPromptFile, "system-prompt", "", "path to system prompt JSON file")
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
v5.0.0-beta-2505161430
|
||||
v5.0.0-beta-2505162305
|
||||
|
||||
12
parser.go
12
parser.go
@@ -24,13 +24,10 @@ import (
|
||||
)
|
||||
|
||||
func NewParser() *Parser {
|
||||
return &Parser{
|
||||
ctx: context.Background(),
|
||||
}
|
||||
return &Parser{}
|
||||
}
|
||||
|
||||
type Parser struct {
|
||||
ctx context.Context
|
||||
Plugin funplugin.IPlugin // plugin is used to call functions
|
||||
MCPHost *mcphost.MCPHost
|
||||
}
|
||||
@@ -308,15 +305,16 @@ func (p *Parser) CallFunc(funcName string, arguments ...interface{}) (interface{
|
||||
}
|
||||
|
||||
// CallMCPTool calls a MCP tool on a specific MCP server
|
||||
func (p *Parser) CallMCPTool(serverName, funcName string, arguments map[string]interface{}) (interface{}, error) {
|
||||
func (p *Parser) CallMCPTool(ctx context.Context, serverName,
|
||||
funcName string, arguments map[string]interface{}) (interface{}, error) {
|
||||
if p.MCPHost == nil {
|
||||
return nil, fmt.Errorf("mcphost is not initialized")
|
||||
}
|
||||
|
||||
tools := p.MCPHost.GetTools(p.ctx)
|
||||
tools := p.MCPHost.GetTools(ctx)
|
||||
log.Warn().Interface("tools", tools).Msg("tools")
|
||||
|
||||
result, err := p.MCPHost.InvokeTool(p.ctx, serverName, funcName, arguments)
|
||||
result, err := p.MCPHost.InvokeTool(ctx, serverName, funcName, arguments)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "invoke tool %s/%s failed", serverName, funcName)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package hrp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -465,7 +466,7 @@ func TestCallMCPTool(t *testing.T) {
|
||||
|
||||
parser := caseRunner.GetParser()
|
||||
|
||||
resp, err := parser.CallMCPTool("filesystem", "read_file",
|
||||
resp, err := parser.CallMCPTool(context.Background(), "filesystem", "read_file",
|
||||
map[string]interface{}{"path": "internal/version/VERSION"})
|
||||
assert.Nil(t, err)
|
||||
t.Logf("resp: %v", resp)
|
||||
|
||||
372
pkg/mcphost/chat.go
Normal file
372
pkg/mcphost/chat.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/glamour/styles"
|
||||
"github.com/charmbracelet/huh/spinner"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/lipgloss/list"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/code"
|
||||
"github.com/httprunner/httprunner/v5/uixt/ai"
|
||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Chat represents a chat session with LLM
|
||||
type Chat struct {
|
||||
model model.ToolCallingChatModel
|
||||
systemPrompt string
|
||||
history ai.ConversationHistory
|
||||
renderer *glamour.TermRenderer
|
||||
host *MCPHost
|
||||
tools []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// Tokyo Night theme colors
|
||||
var (
|
||||
tokyoPurple = lipgloss.Color("99") // #9d7cd8
|
||||
tokyoCyan = lipgloss.Color("73") // #7dcfff
|
||||
tokyoBlue = lipgloss.Color("111") // #7aa2f7
|
||||
tokyoGreen = lipgloss.Color("120") // #73daca
|
||||
tokyoRed = lipgloss.Color("203") // #f7768e
|
||||
tokyoOrange = lipgloss.Color("215") // #ff9e64
|
||||
tokyoFg = lipgloss.Color("189") // #c0caf5
|
||||
tokyoGray = lipgloss.Color("237") // #3b4261
|
||||
tokyoBg = lipgloss.Color("234") // #1a1b26
|
||||
|
||||
toolNameStyle = lipgloss.NewStyle().
|
||||
Foreground(tokyoCyan).
|
||||
Bold(true)
|
||||
|
||||
descriptionStyle = lipgloss.NewStyle().
|
||||
Foreground(tokyoFg).
|
||||
PaddingBottom(1)
|
||||
|
||||
contentStyle = lipgloss.NewStyle().
|
||||
Background(tokyoBg).
|
||||
PaddingLeft(4).
|
||||
PaddingRight(4)
|
||||
)
|
||||
|
||||
// NewChat creates a new chat session
|
||||
func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat, error) {
|
||||
// Get model config from environment variables
|
||||
modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeGPT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
|
||||
}
|
||||
|
||||
// Create markdown renderer
|
||||
renderer, err := glamour.NewTermRenderer(
|
||||
glamour.WithStandardStyle(styles.TokyoNightStyle),
|
||||
glamour.WithWordWrap(getTerminalWidth()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create markdown renderer")
|
||||
}
|
||||
|
||||
// Load system prompt from file if provided
|
||||
systemPrompt := "chat to interact with MCP tools"
|
||||
if systemPromptFile != "" {
|
||||
customPrompt, err := loadSystemPrompt(systemPromptFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load system prompt")
|
||||
}
|
||||
if customPrompt != "" {
|
||||
systemPrompt = customPrompt
|
||||
}
|
||||
}
|
||||
|
||||
// convert MCP tools to eino tool infos
|
||||
einoTools, err := h.GetEinoToolInfos(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get eino tool infos")
|
||||
}
|
||||
|
||||
toolCallingModel, err := model.WithTools(einoTools)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
|
||||
}
|
||||
|
||||
return &Chat{
|
||||
model: toolCallingModel,
|
||||
systemPrompt: systemPrompt,
|
||||
history: ai.ConversationHistory{},
|
||||
renderer: renderer,
|
||||
host: h,
|
||||
tools: einoTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// loadSystemPrompt loads the system prompt from a JSON file
|
||||
func loadSystemPrompt(filePath string) (string, error) {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading prompt file: %v", err)
|
||||
}
|
||||
|
||||
// Read file content directly as prompt
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Start starts the chat session
|
||||
func (c *Chat) Start() error {
|
||||
// Add system message
|
||||
c.history = ai.ConversationHistory{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: c.systemPrompt,
|
||||
},
|
||||
}
|
||||
|
||||
c.showWelcome()
|
||||
|
||||
for {
|
||||
fmt.Print("\nYou: ")
|
||||
input, err := readInput()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle commands
|
||||
if strings.HasPrefix(input, "/") {
|
||||
if err := c.handleCommand(input); err != nil {
|
||||
log.Error().Err(err).Msg("failed to handle command")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// run prompt with MCP tools
|
||||
if err := c.runPrompt(input); err != nil {
|
||||
log.Error().Err(err).Msg("chat error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runPrompt run prompt with MCP tools
|
||||
func (c *Chat) runPrompt(prompt string) error {
|
||||
// Create user message
|
||||
userMsg := &schema.Message{
|
||||
Role: schema.User,
|
||||
Content: prompt,
|
||||
}
|
||||
c.history = append(c.history, userMsg)
|
||||
for {
|
||||
ctx := context.Background()
|
||||
spinner.New().Type(spinner.Dots).Title("Thinking...").Run()
|
||||
resp, err := c.model.Generate(ctx, c.history)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
toolCalls := resp.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
for _, toolCall := range toolCalls {
|
||||
parts := strings.SplitN(toolCall.Function.Name, "__", 2)
|
||||
if len(parts) != 2 {
|
||||
log.Error().Msgf("invalid tool name: %s", toolCall.Function.Name)
|
||||
continue
|
||||
}
|
||||
serverName, toolName := parts[0], parts[1]
|
||||
args := toolCall.Function.Arguments
|
||||
|
||||
// Unmarshal tool arguments from JSON string
|
||||
var argsMap map[string]interface{}
|
||||
if err := sonic.UnmarshalString(args, &argsMap); err != nil {
|
||||
log.Error().Err(err).Str("args", args).Msg("failed to unmarshal tool arguments")
|
||||
continue
|
||||
}
|
||||
|
||||
result, err := c.host.InvokeTool(ctx, serverName, toolName, argsMap)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("tool call failed")
|
||||
continue
|
||||
}
|
||||
|
||||
// Format tool result
|
||||
resultStr := ""
|
||||
if result != nil && len(result.Content) > 0 {
|
||||
for _, item := range result.Content {
|
||||
resultStr += fmt.Sprintf("%v\n", item)
|
||||
}
|
||||
} else {
|
||||
resultStr = fmt.Sprintf("%+v", result)
|
||||
}
|
||||
|
||||
// Add tool result to history
|
||||
toolMsg := &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: resultStr,
|
||||
}
|
||||
c.history = append(c.history, toolMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Add assistant's response to history
|
||||
c.history = append(c.history, resp)
|
||||
|
||||
// Render and display response
|
||||
if rendered, err := c.renderer.Render(resp.Content); err == nil {
|
||||
fmt.Printf("\nAssistant: %s\n", rendered)
|
||||
} else {
|
||||
fmt.Printf("\nAssistant: %s\n", resp.Content)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// showWelcome show welcome and help information
|
||||
func (c *Chat) showWelcome() {
|
||||
markdown := fmt.Sprintf(`# Welcome to HttpRunner MCPHost Chat!
|
||||
|
||||
## Available Commands
|
||||
|
||||
The following commands are available:
|
||||
|
||||
- **/help**: Show this help message
|
||||
- **/tools**: List all available tools
|
||||
- **/history**: Display conversation history
|
||||
- **/clear**: Clear conversation history
|
||||
- **/quit**: Exit the chat session
|
||||
|
||||
You can also press Ctrl+C at any time to quit.
|
||||
|
||||
## Configurations
|
||||
|
||||
- **system-prompt**: %s
|
||||
- **mcp-config**: %s
|
||||
`, c.systemPrompt, c.host.config.ConfigPath)
|
||||
|
||||
str, err := c.renderer.Render(markdown)
|
||||
if err != nil {
|
||||
fmt.Println(markdown)
|
||||
} else {
|
||||
fmt.Print(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Chat) handleCommand(cmd string) error {
|
||||
switch cmd {
|
||||
case "/help":
|
||||
c.showWelcome()
|
||||
case "/tools":
|
||||
c.showTools()
|
||||
case "/history":
|
||||
c.showHistory()
|
||||
case "/clear":
|
||||
c.clearHistory()
|
||||
case "/quit":
|
||||
fmt.Println("Goodbye!")
|
||||
os.Exit(0)
|
||||
default:
|
||||
fmt.Printf("Unknown command: %s\n", cmd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Chat) showHistory() {
|
||||
if len(c.history) <= 1 { // Only system message
|
||||
fmt.Println("No conversation history yet.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("\nConversation History:")
|
||||
for _, msg := range c.history {
|
||||
if msg.Role == schema.System {
|
||||
continue
|
||||
}
|
||||
|
||||
role := "You"
|
||||
if msg.Role == schema.Assistant {
|
||||
role = "Assistant"
|
||||
}
|
||||
|
||||
// Render message content as markdown
|
||||
rendered, err := c.renderer.Render(msg.Content)
|
||||
if err != nil {
|
||||
rendered = msg.Content
|
||||
}
|
||||
|
||||
fmt.Printf("\n%s: %s\n", role, rendered)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Chat) clearHistory() {
|
||||
// Keep only the system message
|
||||
systemMsg := c.history[0]
|
||||
c.history = ai.ConversationHistory{systemMsg}
|
||||
fmt.Println("Conversation history cleared.")
|
||||
}
|
||||
|
||||
func (c *Chat) showTools() {
|
||||
if c.host == nil {
|
||||
fmt.Println("No MCP host loaded.")
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
results := c.host.GetTools(ctx)
|
||||
if len(results) == 0 {
|
||||
fmt.Println("No MCP servers loaded.")
|
||||
return
|
||||
}
|
||||
width := getTerminalWidth()
|
||||
contentWidth := width - 12
|
||||
l := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoPurple).MarginRight(1))
|
||||
for server, tools := range results {
|
||||
serverList := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoCyan).MarginRight(1))
|
||||
if tools.Err != nil {
|
||||
serverList.Item(contentStyle.Render(fmt.Sprintf("Error: %v", tools.Err)))
|
||||
} else if len(tools.Tools) == 0 {
|
||||
serverList.Item(contentStyle.Render("No tools available."))
|
||||
} else {
|
||||
for _, tool := range tools.Tools {
|
||||
descStyle := lipgloss.NewStyle().Foreground(tokyoFg).Width(contentWidth).Align(lipgloss.Left)
|
||||
toolDesc := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoGreen).MarginRight(1)).Item(descStyle.Render(tool.Description))
|
||||
serverList.Item(toolNameStyle.Render(tool.Name)).Item(toolDesc)
|
||||
}
|
||||
}
|
||||
l.Item(server).Item(serverList)
|
||||
}
|
||||
containerStyle := lipgloss.NewStyle().Margin(2).Width(width)
|
||||
fmt.Print("\n" + containerStyle.Render(l.String()) + "\n")
|
||||
}
|
||||
|
||||
func readInput() (string, error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(input), nil
|
||||
}
|
||||
|
||||
func getTerminalWidth() int {
|
||||
width, _, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil {
|
||||
return 80 // Fallback width
|
||||
}
|
||||
return width - 20
|
||||
}
|
||||
50
pkg/mcphost/chat_test.go
Normal file
50
pkg/mcphost/chat_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package mcphost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewChat(t *testing.T) {
|
||||
systemPromptFile := "test_system_prompt.txt"
|
||||
_ = os.WriteFile(systemPromptFile, []byte("You are a helpful assistant."), 0o644)
|
||||
defer os.Remove(systemPromptFile)
|
||||
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), systemPromptFile)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, chat)
|
||||
assert.NotEmpty(t, chat.systemPrompt)
|
||||
assert.NotNil(t, chat.tools)
|
||||
}
|
||||
|
||||
func TestRunPromptWithNoToolCall(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = chat.runPrompt("hi")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.history) > 1)
|
||||
}
|
||||
|
||||
func TestRunPromptWithToolCall(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.tools) > 0)
|
||||
|
||||
err = chat.runPrompt("what is the weather in CA")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.history) > 1)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ const (
|
||||
|
||||
// MCPConfig represents the configuration for MCP servers
|
||||
type MCPConfig struct {
|
||||
ConfigPath string `json:"-"`
|
||||
MCPServers map[string]ServerConfigWrapper `json:"mcpServers"`
|
||||
}
|
||||
|
||||
@@ -120,6 +121,7 @@ func LoadMCPConfig(configPath string) (*MCPConfig, error) {
|
||||
if err := json.Unmarshal(configData, &config); err != nil {
|
||||
return nil, fmt.Errorf("error parsing config file: %w", err)
|
||||
}
|
||||
config.ConfigPath = configPath
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -185,33 +183,3 @@ func (h *MCPHost) ExportToolsToJSON(ctx context.Context, dumpPath string) error
|
||||
log.Info().Str("path", dumpPath).Msg("Tools records exported successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEinoTool returns an eino tool from the MCP server
|
||||
func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string) (tool.BaseTool, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// filter MCP server by serverName
|
||||
conn, exists := h.connections[serverName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
}
|
||||
|
||||
if conn.Config.IsDisabled() {
|
||||
return nil, fmt.Errorf("server %s is disabled", serverName)
|
||||
}
|
||||
|
||||
// get tools from MCP server and convert to eino tools
|
||||
tools, err := mcpp.GetTools(ctx, &mcpp.Config{
|
||||
Cli: conn.Client,
|
||||
ToolNameList: []string{toolName},
|
||||
})
|
||||
if err != nil || len(tools) == 0 {
|
||||
log.Error().Err(err).
|
||||
Str("server", serverName).Str("tool", toolName).
|
||||
Msg("get MCP tool failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tools[0], nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -17,12 +16,8 @@ func TestConvertToolsToRecordsFromFile(t *testing.T) {
|
||||
hub, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = hub.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// use ExportToolsToJSON to dump tools to JSON file
|
||||
err = hub.ExportToolsToJSON(ctx, "./tools_records.json")
|
||||
err = hub.ExportToolsToJSON(context.Background(), "./tools_records.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// read the exported JSON file
|
||||
@@ -240,21 +235,3 @@ func TestConvertToolsToRecords(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallEinoTool(t *testing.T) {
|
||||
hub, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = hub.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
einoTool, err := hub.GetEinoTool(ctx, "weather", "get_alerts")
|
||||
require.NoError(t, err)
|
||||
t.Logf("Tool: %v", einoTool)
|
||||
|
||||
tool := einoTool.(tool.InvokableTool)
|
||||
result, err := tool.InvokableRun(ctx, `{"state": "CA"}`)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Result: %v", result)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/internal/version"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -42,10 +45,18 @@ func NewMCPHost(configPath string) (*MCPHost, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MCPHost{
|
||||
|
||||
host := &MCPHost{
|
||||
connections: make(map[string]*Connection),
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Initialize MCP servers
|
||||
if err := host.InitServers(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize MCP servers: %w", err)
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// parseHeaders parses header strings into a map
|
||||
@@ -296,3 +307,66 @@ func (h *MCPHost) CloseServers() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEinoTool returns an eino tool from the MCP server
|
||||
func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string) (tool.BaseTool, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// filter MCP server by serverName
|
||||
conn, exists := h.connections[serverName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
}
|
||||
|
||||
if conn.Config.IsDisabled() {
|
||||
return nil, fmt.Errorf("server %s is disabled", serverName)
|
||||
}
|
||||
|
||||
// get tools from MCP server and convert to eino tools
|
||||
tools, err := mcpp.GetTools(ctx, &mcpp.Config{
|
||||
Cli: conn.Client,
|
||||
ToolNameList: []string{toolName},
|
||||
})
|
||||
if err != nil || len(tools) == 0 {
|
||||
log.Error().Err(err).
|
||||
Str("server", serverName).Str("tool", toolName).
|
||||
Msg("get MCP tool failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tools[0], nil
|
||||
}
|
||||
|
||||
// GetEinoToolInfos convert MCP tools to eino tool infos
|
||||
func (h *MCPHost) GetEinoToolInfos(ctx context.Context) ([]*schema.ToolInfo, error) {
|
||||
var allTools []*schema.ToolInfo
|
||||
for serverName, serverTools := range h.GetTools(ctx) {
|
||||
if serverTools.Err != nil {
|
||||
log.Error().
|
||||
Err(serverTools.Err).
|
||||
Str("server", serverName).
|
||||
Msg("Error fetching tools")
|
||||
continue
|
||||
}
|
||||
for _, tool := range serverTools.Tools {
|
||||
einoTool, err := h.GetEinoTool(ctx, serverName, tool.Name)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to get eino tool")
|
||||
continue
|
||||
}
|
||||
einoToolInfo, err := einoTool.Info(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to get eino tool info")
|
||||
continue
|
||||
}
|
||||
allTools = append(allTools, einoToolInfo)
|
||||
}
|
||||
log.Info().
|
||||
Str("server", serverName).
|
||||
Int("count", len(serverTools.Tools)).
|
||||
Msg("eino tool infos loaded")
|
||||
}
|
||||
|
||||
return allTools, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -27,10 +28,6 @@ func TestInitServers(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify connections are established
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
assert.Contains(t, host.connections, "filesystem")
|
||||
@@ -41,10 +38,6 @@ func TestGetClient(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test getting existing client
|
||||
client, err := host.GetClient("weather")
|
||||
require.NoError(t, err)
|
||||
@@ -61,9 +54,6 @@ func TestGetTools(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
assert.Contains(t, tools, "weather")
|
||||
@@ -90,8 +80,6 @@ func TestGetTool(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test getting existing tool
|
||||
tool, err := host.GetTool(ctx, "weather", "get_alerts")
|
||||
@@ -115,8 +103,6 @@ func TestInvokeTool(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test invoking existing tool
|
||||
result, err := host.InvokeTool(ctx, "weather", "get_alerts",
|
||||
@@ -140,12 +126,23 @@ func TestInvokeTool(t *testing.T) {
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCloseServers(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
func TestCallEinoTool(t *testing.T) {
|
||||
hub, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
einoTool, err := hub.GetEinoTool(ctx, "weather", "get_alerts")
|
||||
require.NoError(t, err)
|
||||
t.Logf("Tool: %v", einoTool)
|
||||
|
||||
tool := einoTool.(tool.InvokableTool)
|
||||
result, err := tool.InvokableRun(ctx, `{"state": "CA"}`)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Result: %v", result)
|
||||
}
|
||||
|
||||
func TestCloseServers(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify servers are connected
|
||||
@@ -163,17 +160,14 @@ func TestConcurrentOperations(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent tool invocations
|
||||
done := make(chan bool)
|
||||
timeout := time.After(30 * time.Second) // Increase timeout to 30 seconds
|
||||
|
||||
for i := 0; i < 3; i++ { // Reduce number of concurrent operations to 3
|
||||
go func() {
|
||||
result, err := host.InvokeTool(ctx, "weather", "get_alerts",
|
||||
result, err := host.InvokeTool(
|
||||
context.Background(), "weather", "get_alerts",
|
||||
map[string]interface{}{"state": "CA"},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
@@ -197,10 +191,6 @@ func TestDisabledServer(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = host.InitServers(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only enabled servers are connected
|
||||
assert.Equal(t, 2, len(host.connections))
|
||||
assert.Contains(t, host.connections, "filesystem")
|
||||
@@ -214,6 +204,7 @@ func TestDisabledServer(t *testing.T) {
|
||||
assert.Nil(t, client)
|
||||
|
||||
// Test getting tools from disabled server
|
||||
ctx := context.Background()
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
assert.Contains(t, tools, "filesystem")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package hrp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
@@ -324,11 +323,6 @@ func NewCaseRunner(testcase TestCase, hrpRunner *HRPRunner) (*CaseRunner, error)
|
||||
log.Error().Err(err).Msg("init MCP hub failed")
|
||||
return nil, err
|
||||
}
|
||||
err = mcpHost.InitServers(context.Background())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("init MCP servers failed")
|
||||
return nil, err
|
||||
}
|
||||
caseRunner.parser.MCPHost = mcpHost
|
||||
log.Info().Str("mcpConfigPath", config.MCPConfigPath).Msg("mcp server loaded")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -31,13 +30,6 @@ func (r *Router) InitMCPHost(configPath string) error {
|
||||
log.Error().Err(err).Msg("init MCP host failed")
|
||||
return err
|
||||
}
|
||||
|
||||
err = mcpHost.InitServers(context.Background())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("init MCP servers failed")
|
||||
return err
|
||||
}
|
||||
|
||||
r.mcpHost = mcpHost
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user