refactor: mcphost call tools

This commit is contained in:
lilong.129
2025-05-17 11:32:51 +08:00
parent 8346fb179c
commit 5d8c22f729
6 changed files with 451 additions and 329 deletions

View File

@@ -24,44 +24,6 @@ import (
"golang.org/x/term"
)
var (
// Tokyo Night theme colors
tokyoPurple = lipgloss.Color("99") // #9d7cd8
tokyoCyan = lipgloss.Color("73") // #7dcfff
tokyoBlue = lipgloss.Color("111") // #7aa2f7
tokyoGreen = lipgloss.Color("120") // #73daca
tokyoRed = lipgloss.Color("203") // #f7768e
tokyoOrange = lipgloss.Color("215") // #ff9e64
tokyoFg = lipgloss.Color("189") // #c0caf5
tokyoGray = lipgloss.Color("237") // #3b4261
tokyoBg = lipgloss.Color("234") // #1a1b26
promptStyle = lipgloss.NewStyle().
Foreground(tokyoBlue).
PaddingLeft(2)
responseStyle = lipgloss.NewStyle().
Foreground(tokyoFg).
PaddingLeft(2)
errorStyle = lipgloss.NewStyle().
Foreground(tokyoRed).
Bold(true)
toolNameStyle = lipgloss.NewStyle().
Foreground(tokyoCyan).
Bold(true)
descriptionStyle = lipgloss.NewStyle().
Foreground(tokyoFg).
PaddingBottom(1)
contentStyle = lipgloss.NewStyle().
Background(tokyoBg).
PaddingLeft(4).
PaddingRight(4)
)
// NewChat creates a new chat session
func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat, error) {
// Get model config from environment variables
@@ -170,7 +132,7 @@ func (c *Chat) Start() error {
// run prompt with MCP tools
if err := c.runPrompt(input); err != nil {
log.Error().Err(err).Msg("chat error")
log.Error().Err(err).Msg("run prompt error")
}
}
}
@@ -185,75 +147,110 @@ func (c *Chat) runPrompt(prompt string) error {
Content: prompt,
}
c.history = append(c.history, userMsg)
for {
ctx := context.Background()
var resp *schema.Message
var err error
action := func() {
resp, err = c.model.Generate(ctx, c.history)
}
_ = spinner.New().Title("Thinking...").Action(action).Run()
if err != nil {
return err
// Call LLM model to get response
ctx := context.Background()
var message *schema.Message
var modelErr error
_ = spinner.New().Title("Thinking...").Action(func() {
message, modelErr = c.model.Generate(ctx, c.history)
}).Run()
if modelErr != nil {
return modelErr
}
// Log usage statistics
if usage := message.ResponseMeta.Usage; usage != nil {
log.Debug().Int("input_tokens", usage.PromptTokens).
Int("output_tokens", usage.CompletionTokens).
Int("total_tokens", usage.TotalTokens).Msg("Usage statistics")
}
// Handle tool calls
toolCalls := message.ToolCalls
if len(toolCalls) > 0 {
return c.handleToolCalls(ctx, toolCalls)
}
// Add assistant's response to history
toolMsg := &schema.Message{
Role: schema.Assistant,
Content: message.Content,
}
c.history = append(c.history, toolMsg)
c.renderContent("Assistant", message.Content)
return nil
}
func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall) error {
for _, toolCall := range toolCalls {
serverToolName := toolCall.Function.Name
toolArgs := toolCall.Function.Arguments
log.Debug().Str("name", serverToolName).Str("args", toolArgs).Msg("handle tool call")
// Parse tool name
parts := strings.SplitN(serverToolName, "__", 2)
if len(parts) != 2 {
log.Error().Str("name", serverToolName).Msg("invalid tool name")
continue
}
serverName, toolName := parts[0], parts[1]
// Handle tool calls
toolCalls := resp.ToolCalls
if len(toolCalls) > 0 {
for _, toolCall := range toolCalls {
parts := strings.SplitN(toolCall.Function.Name, "__", 2)
if len(parts) != 2 {
log.Error().Msgf("invalid tool name: %s", toolCall.Function.Name)
continue
}
serverName, toolName := parts[0], parts[1]
args := toolCall.Function.Arguments
// Unmarshal tool arguments from JSON string
var argsMap map[string]interface{}
if err := sonic.UnmarshalString(args, &argsMap); err != nil {
log.Error().Err(err).Str("args", args).Msg("failed to unmarshal tool arguments")
continue
}
result, err := c.host.InvokeTool(ctx, serverName, toolName, argsMap)
if err != nil {
log.Error().Err(err).Msg("tool call failed")
continue
}
// Format tool result
resultStr := ""
if result != nil && len(result.Content) > 0 {
for _, item := range result.Content {
resultStr += fmt.Sprintf("%v\n", item)
}
} else {
resultStr = fmt.Sprintf("%+v", result)
}
// Add tool result to history
toolMsg := &schema.Message{
Role: schema.Assistant,
Content: resultStr,
}
c.history = append(c.history, toolMsg)
}
// Unmarshal tool arguments from JSON string
var argsMap map[string]interface{}
if err := sonic.UnmarshalString(toolArgs, &argsMap); err != nil {
log.Error().Err(err).Str("args", toolArgs).Msg("failed to unmarshal tool arguments")
continue
}
// Add assistant's response to history
c.history = append(c.history, resp)
// Render and display response
if rendered, err := c.renderer.Render(resp.Content); err == nil {
fmt.Printf("\n%s", responseStyle.Render("Assistant: "+rendered))
} else {
fmt.Printf("\n%s", errorStyle.Render("Assistant: "+resp.Content))
// Invoke tool
result, err := c.host.InvokeTool(ctx, serverName, toolName, argsMap)
if err != nil {
log.Error().Err(err).Msg("invoke tool failed")
continue
}
return nil
// Format tool result
resultStr := ""
if result != nil && len(result.Content) > 0 {
for _, item := range result.Content {
resultStr += fmt.Sprintf("%v\n", item)
}
} else {
resultStr = fmt.Sprintf("%+v", result)
}
c.renderContent("Tool result", resultStr)
// Add tool result to history
toolMsg := &schema.Message{
Role: schema.Tool,
Content: resultStr,
ToolCallID: toolCall.ID,
}
c.history = append(c.history, toolMsg)
}
return nil
}
// handleCommand handles commands
func (c *Chat) handleCommand(cmd string) error {
switch cmd {
case "/help":
c.showWelcome()
case "/tools":
c.showTools()
case "/history":
c.showHistory()
case "/clear":
c.clearHistory()
case "/quit":
fmt.Println("Goodbye!")
os.Exit(0)
default:
fmt.Printf("Unknown command: %s\n", cmd)
}
return nil
}
// showWelcome show welcome and help information
@@ -278,31 +275,7 @@ You can also press Ctrl+C at any time to quit.
- **mcp-config**: %s
`, c.systemPrompt, c.host.config.ConfigPath)
str, err := c.renderer.Render(markdown)
if err != nil {
fmt.Println(markdown)
} else {
fmt.Print(str)
}
}
func (c *Chat) handleCommand(cmd string) error {
switch cmd {
case "/help":
c.showWelcome()
case "/tools":
c.showTools()
case "/history":
c.showHistory()
case "/clear":
c.clearHistory()
case "/quit":
fmt.Println("Goodbye!")
os.Exit(0)
default:
fmt.Printf("Unknown command: %s\n", cmd)
}
return nil
c.renderContent("", markdown)
}
func (c *Chat) showHistory() {
@@ -321,14 +294,7 @@ func (c *Chat) showHistory() {
if msg.Role == schema.Assistant {
role = "Assistant"
}
// Render message content as markdown
rendered, err := c.renderer.Render(msg.Content)
if err != nil {
rendered = msg.Content
}
fmt.Printf("\n%s: %s\n", role, rendered)
c.renderContent(role, msg.Content)
}
}
@@ -374,6 +340,19 @@ func (c *Chat) showTools() {
fmt.Print("\n" + containerStyle.Render(l.String()) + "\n")
}
// Render and display content
func (c *Chat) renderContent(title, content string) {
output, err := c.renderer.Render(content)
if err != nil {
log.Error().Err(err).Msg("render content failed")
output = content
}
if title != "" {
title = title + ": "
}
fmt.Printf("\n%s", responseStyle.Render(title+output))
}
// loadSystemPrompt loads the system prompt from a JSON file
func loadSystemPrompt(filePath string) (string, error) {
// Check if file exists
@@ -397,3 +376,41 @@ func getTerminalWidth() int {
}
return width - 20
}
var (
// Tokyo Night theme colors
tokyoPurple = lipgloss.Color("99") // #9d7cd8
tokyoCyan = lipgloss.Color("73") // #7dcfff
tokyoBlue = lipgloss.Color("111") // #7aa2f7
tokyoGreen = lipgloss.Color("120") // #73daca
tokyoRed = lipgloss.Color("203") // #f7768e
tokyoOrange = lipgloss.Color("215") // #ff9e64
tokyoFg = lipgloss.Color("189") // #c0caf5
tokyoGray = lipgloss.Color("237") // #3b4261
tokyoBg = lipgloss.Color("234") // #1a1b26
promptStyle = lipgloss.NewStyle().
Foreground(tokyoBlue).
PaddingLeft(2)
responseStyle = lipgloss.NewStyle().
Foreground(tokyoFg).
PaddingLeft(2)
errorStyle = lipgloss.NewStyle().
Foreground(tokyoRed).
Bold(true)
toolNameStyle = lipgloss.NewStyle().
Foreground(tokyoCyan).
Bold(true)
descriptionStyle = lipgloss.NewStyle().
Foreground(tokyoFg).
PaddingBottom(1)
contentStyle = lipgloss.NewStyle().
Background(tokyoBg).
PaddingLeft(4).
PaddingRight(4)
)

View File

@@ -36,15 +36,15 @@ func TestRunPromptWithNoToolCall(t *testing.T) {
assert.True(t, len(chat.history) > 1)
}
// func TestRunPromptWithToolCall(t *testing.T) {
// host, err := NewMCPHost("./testdata/test.mcp.json")
// require.NoError(t, err)
func TestRunPromptWithToolCall(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
// chat, err := host.NewChat(context.Background(), "")
// assert.NoError(t, err)
// assert.True(t, len(chat.tools) > 0)
chat, err := host.NewChat(context.Background(), "")
assert.NoError(t, err)
assert.True(t, len(chat.tools) > 0)
// err = chat.runPrompt("what is the weather in CA")
// assert.NoError(t, err)
// assert.True(t, len(chat.history) > 1)
// }
err = chat.runPrompt("what is the weather in CA")
assert.NoError(t, err)
assert.True(t, len(chat.history) > 1)
}

View File

@@ -19,13 +19,6 @@ import (
"github.com/rs/zerolog/log"
)
// MCPTools represents tools from a single MCP server
type MCPTools struct {
ServerName string
Tools []mcp.Tool
Err error
}
// MCPHost manages MCP server connections and tools
type MCPHost struct {
mu sync.RWMutex
@@ -39,6 +32,13 @@ type Connection struct {
Config ServerConfig
}
// MCPTools represents tools from a single MCP server
type MCPTools struct {
ServerName string
Tools []mcp.Tool
Err error
}
// NewMCPHost creates a new MCPHost instance
func NewMCPHost(configPath string) (*MCPHost, error) {
config, err := LoadMCPConfig(configPath)
@@ -59,46 +59,6 @@ func NewMCPHost(configPath string) (*MCPHost, error) {
return host, nil
}
// parseHeaders parses header strings into a map
func parseHeaders(headerList []string) map[string]string {
headers := make(map[string]string)
for _, header := range headerList {
parts := strings.SplitN(header, ":", 2)
if len(parts) == 2 {
headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
}
return headers
}
// startStdioLog starts a goroutine to print stdio logs
func startStdioLog(stderr io.Reader, serverName string) {
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
}
}()
}
// prepareClientInitRequest creates a standard initialization request
func prepareClientInitRequest() mcp.InitializeRequest {
return mcp.InitializeRequest{
Params: struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities mcp.ClientCapabilities `json:"capabilities"`
ClientInfo mcp.Implementation `json:"clientInfo"`
}{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
Capabilities: mcp.ClientCapabilities{},
ClientInfo: mcp.Implementation{
Name: "hrp-mcphost",
Version: version.GetVersionInfo(),
},
},
}
}
// InitServers initializes all MCP servers
func (h *MCPHost) InitServers(ctx context.Context) error {
for name, server := range h.config.MCPServers {
@@ -113,19 +73,6 @@ func (h *MCPHost) InitServers(ctx context.Context) error {
return nil
}
// GetClient returns the client for the specified server
func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
h.mu.RLock()
defer h.mu.RUnlock()
conn, exists := h.connections[serverName]
if !exists {
return nil, fmt.Errorf("no connection found for server %s", serverName)
}
return conn.Client, nil
}
// connectToServer establishes connection to a single MCP server
func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config ServerConfig) error {
h.mu.Lock()
@@ -147,7 +94,8 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
// create client based on server type
switch cfg := config.(type) {
case SSEServerConfig:
mcpClient, err = client.NewSSEMCPClient(cfg.Url, client.WithHeaders(parseHeaders(cfg.Headers)))
mcpClient, err = client.NewSSEMCPClient(cfg.Url,
client.WithHeaders(parseHeaders(cfg.Headers)))
case STDIOServerConfig:
env := make([]string, 0, len(cfg.Env))
for k, v := range cfg.Env {
@@ -181,6 +129,37 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
return nil
}
// CloseServers closes all connected MCP servers
func (h *MCPHost) CloseServers() error {
h.mu.Lock()
defer h.mu.Unlock()
log.Info().Msg("Shutting down MCP servers...")
for name, conn := range h.connections {
if err := conn.Client.Close(); err != nil {
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
} else {
delete(h.connections, name)
log.Info().Str("name", name).Msg("Server closed")
}
}
return nil
}
// GetClient returns the client for the specified server
func (h *MCPHost) GetClient(serverName string) (client.MCPClient, error) {
h.mu.RLock()
defer h.mu.RUnlock()
conn, exists := h.connections[serverName]
if !exists {
return nil, fmt.Errorf("no connection found for server %s", serverName)
}
return conn.Client, nil
}
// GetTools returns all tools from all MCP servers
func (h *MCPHost) GetTools(ctx context.Context) []MCPTools {
h.mu.RLock()
@@ -189,10 +168,6 @@ func (h *MCPHost) GetTools(ctx context.Context) []MCPTools {
var results []MCPTools
for serverName, conn := range h.connections {
if conn.Config.IsDisabled() {
continue
}
listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
log.Error().Err(err).Str("server", serverName).Msg("failed to get tools")
@@ -244,17 +219,6 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc
return nil, fmt.Errorf("tool %s not found", toolName)
}
// handleToolError handles tool execution errors
func handleToolError(result *mcp.CallToolResult) error {
if !result.IsError {
return nil
}
if len(result.Content) > 0 {
return fmt.Errorf("tool error: %v", result.Content[0])
}
return fmt.Errorf("tool error: unknown error")
}
// InvokeTool calls a tool with the given arguments
func (h *MCPHost) InvokeTool(ctx context.Context,
serverName, toolName string, arguments map[string]any,
@@ -300,24 +264,6 @@ func (h *MCPHost) InvokeTool(ctx context.Context,
return result, nil
}
// CloseServers closes all connected MCP servers
func (h *MCPHost) CloseServers() error {
h.mu.Lock()
defer h.mu.Unlock()
log.Info().Msg("Shutting down MCP servers...")
for name, conn := range h.connections {
if err := conn.Client.Close(); err != nil {
log.Error().Str("name", name).Err(err).Msg("Failed to close server")
} else {
delete(h.connections, name)
log.Info().Str("name", name).Msg("Server closed")
}
}
return nil
}
// GetEinoTool returns an eino tool for the given server and tool name
func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string) (tool.BaseTool, error) {
h.mu.RLock()
@@ -328,10 +274,6 @@ func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string)
return nil, fmt.Errorf("server not found: %s", serverName)
}
if conn.Config.IsDisabled() {
return nil, fmt.Errorf("server %s is disabled", serverName)
}
// get tools from MCP server and convert to eino tools
tools, err := mcpp.GetTools(ctx, &mcpp.Config{
Cli: conn.Client,
@@ -372,6 +314,7 @@ func (h *MCPHost) GetEinoToolInfos(ctx context.Context) ([]*schema.ToolInfo, err
log.Error().Err(err).Str("server", serverTools.ServerName).Str("tool", tool.Name).Msg("failed to get eino tool info")
continue
}
einoToolInfo.Name = fmt.Sprintf("%s__%s", serverTools.ServerName, tool.Name)
tools = append(tools, einoToolInfo)
}
}
@@ -379,3 +322,54 @@ func (h *MCPHost) GetEinoToolInfos(ctx context.Context) ([]*schema.ToolInfo, err
log.Info().Int("count", len(tools)).Msg("eino tool infos loaded")
return tools, nil
}
// parseHeaders parses header strings into a map
func parseHeaders(headerList []string) map[string]string {
headers := make(map[string]string)
for _, header := range headerList {
parts := strings.SplitN(header, ":", 2)
if len(parts) == 2 {
headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
}
return headers
}
// startStdioLog starts a goroutine to print stdio logs
func startStdioLog(stderr io.Reader, serverName string) {
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
fmt.Fprintf(os.Stderr, "MCP Server %s: %s\n", serverName, scanner.Text())
}
}()
}
// prepareClientInitRequest creates a standard initialization request
func prepareClientInitRequest() mcp.InitializeRequest {
return mcp.InitializeRequest{
Params: struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities mcp.ClientCapabilities `json:"capabilities"`
ClientInfo mcp.Implementation `json:"clientInfo"`
}{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
Capabilities: mcp.ClientCapabilities{},
ClientInfo: mcp.Implementation{
Name: "hrp-mcphost",
Version: version.GetVersionInfo(),
},
},
}
}
// handleToolError handles tool execution errors
func handleToolError(result *mcp.CallToolResult) error {
if !result.IsError {
return nil
}
if len(result.Content) > 0 {
return fmt.Errorf("tool error: %v", result.Content[0])
}
return fmt.Errorf("tool error: unknown error")
}