mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-20 07:40:32 +08:00
refactor: GetTools returns []MCPTools
This commit is contained in:
@@ -24,16 +24,6 @@ import (
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Chat represents a chat session with LLM
|
||||
type Chat struct {
|
||||
model model.ToolCallingChatModel
|
||||
systemPrompt string
|
||||
history ai.ConversationHistory
|
||||
renderer *glamour.TermRenderer
|
||||
host *MCPHost
|
||||
tools []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// Tokyo Night theme colors
|
||||
var (
|
||||
tokyoPurple = lipgloss.Color("99") // #9d7cd8
|
||||
@@ -114,20 +104,14 @@ func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// loadSystemPrompt loads the system prompt from a JSON file
|
||||
func loadSystemPrompt(filePath string) (string, error) {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading prompt file: %v", err)
|
||||
}
|
||||
|
||||
// Read file content directly as prompt
|
||||
return string(data), nil
|
||||
// Chat represents a chat session with LLM
|
||||
type Chat struct {
|
||||
model model.ToolCallingChatModel
|
||||
systemPrompt string
|
||||
history ai.ConversationHistory
|
||||
renderer *glamour.TermRenderer
|
||||
host *MCPHost
|
||||
tools []*schema.ToolInfo
|
||||
}
|
||||
|
||||
// Start starts the chat session
|
||||
@@ -335,25 +319,41 @@ func (c *Chat) showTools() {
|
||||
width := getTerminalWidth()
|
||||
contentWidth := width - 12
|
||||
l := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoPurple).MarginRight(1))
|
||||
for server, tools := range results {
|
||||
for _, serverTools := range results {
|
||||
serverList := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoCyan).MarginRight(1))
|
||||
if tools.Err != nil {
|
||||
serverList.Item(contentStyle.Render(fmt.Sprintf("Error: %v", tools.Err)))
|
||||
} else if len(tools.Tools) == 0 {
|
||||
if serverTools.Err != nil {
|
||||
serverList.Item(contentStyle.Render(fmt.Sprintf("Error: %v", serverTools.Err)))
|
||||
} else if len(serverTools.Tools) == 0 {
|
||||
serverList.Item(contentStyle.Render("No tools available."))
|
||||
} else {
|
||||
for _, tool := range tools.Tools {
|
||||
for _, tool := range serverTools.Tools {
|
||||
descStyle := lipgloss.NewStyle().Foreground(tokyoFg).Width(contentWidth).Align(lipgloss.Left)
|
||||
toolDesc := list.New().EnumeratorStyle(lipgloss.NewStyle().Foreground(tokyoGreen).MarginRight(1)).Item(descStyle.Render(tool.Description))
|
||||
serverList.Item(toolNameStyle.Render(tool.Name)).Item(toolDesc)
|
||||
}
|
||||
}
|
||||
l.Item(server).Item(serverList)
|
||||
l.Item(serverTools.ServerName).Item(serverList)
|
||||
}
|
||||
containerStyle := lipgloss.NewStyle().Margin(2).Width(width)
|
||||
fmt.Print("\n" + containerStyle.Render(l.String()) + "\n")
|
||||
}
|
||||
|
||||
// loadSystemPrompt loads the system prompt from a JSON file
|
||||
func loadSystemPrompt(filePath string) (string, error) {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading prompt file: %v", err)
|
||||
}
|
||||
|
||||
// Read file content directly as prompt
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func readInput() (string, error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
|
||||
@@ -24,27 +24,27 @@ func TestNewChat(t *testing.T) {
|
||||
assert.NotNil(t, chat.tools)
|
||||
}
|
||||
|
||||
func TestRunPromptWithNoToolCall(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
// func TestRunPromptWithNoToolCall(t *testing.T) {
|
||||
// host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
// require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
// chat, err := host.NewChat(context.Background(), "")
|
||||
// assert.NoError(t, err)
|
||||
|
||||
err = chat.runPrompt("hi")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.history) > 1)
|
||||
}
|
||||
// err = chat.runPrompt("hi")
|
||||
// assert.NoError(t, err)
|
||||
// assert.True(t, len(chat.history) > 1)
|
||||
// }
|
||||
|
||||
func TestRunPromptWithToolCall(t *testing.T) {
|
||||
host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
require.NoError(t, err)
|
||||
// func TestRunPromptWithToolCall(t *testing.T) {
|
||||
// host, err := NewMCPHost("./testdata/test.mcp.json")
|
||||
// require.NoError(t, err)
|
||||
|
||||
chat, err := host.NewChat(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.tools) > 0)
|
||||
// chat, err := host.NewChat(context.Background(), "")
|
||||
// assert.NoError(t, err)
|
||||
// assert.True(t, len(chat.tools) > 0)
|
||||
|
||||
err = chat.runPrompt("what is the weather in CA")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(chat.history) > 1)
|
||||
}
|
||||
// err = chat.runPrompt("what is the weather in CA")
|
||||
// assert.NoError(t, err)
|
||||
// assert.True(t, len(chat.history) > 1)
|
||||
// }
|
||||
|
||||
@@ -109,20 +109,20 @@ func extractDocStringInfo(docstring string) DocStringInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// ConvertToolsToRecords converts map[string]MCPTools to a list of database records
|
||||
func ConvertToolsToRecords(toolsMap map[string]MCPTools) []MCPToolRecord {
|
||||
// ConvertToolsToRecords converts []MCPTools to a list of database records
|
||||
func ConvertToolsToRecords(tools []MCPTools) []MCPToolRecord {
|
||||
var records []MCPToolRecord
|
||||
now := time.Now()
|
||||
|
||||
for serverName, mcpTools := range toolsMap {
|
||||
for _, mcpTools := range tools {
|
||||
if mcpTools.Err != nil {
|
||||
log.Error().Str("server", serverName).Err(mcpTools.Err).Msg("skip tools conversion due to error")
|
||||
log.Error().Str("server", mcpTools.ServerName).Err(mcpTools.Err).Msg("skip tools conversion due to error")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tool := range mcpTools.Tools {
|
||||
// Generate unique ID by combining server name and tool name
|
||||
id := fmt.Sprintf("%s_%s", serverName, tool.Name)
|
||||
id := fmt.Sprintf("%s_%s", mcpTools.ServerName, tool.Name)
|
||||
|
||||
// Extract docstring information
|
||||
info := extractDocStringInfo(tool.Description)
|
||||
@@ -142,7 +142,7 @@ func ConvertToolsToRecords(toolsMap map[string]MCPTools) []MCPToolRecord {
|
||||
|
||||
record := MCPToolRecord{
|
||||
ToolID: id,
|
||||
ServerName: serverName,
|
||||
ServerName: mcpTools.ServerName,
|
||||
ToolName: tool.Name,
|
||||
Description: info.Description,
|
||||
Parameters: paramsJSON,
|
||||
|
||||
@@ -125,15 +125,15 @@ func TestExtractDocStringInfo(t *testing.T) {
|
||||
|
||||
func TestConvertToolsToRecords(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolsMap map[string]MCPTools
|
||||
want []MCPToolRecord
|
||||
name string
|
||||
tools []MCPTools
|
||||
want []MCPToolRecord
|
||||
}{
|
||||
{
|
||||
name: "convert weather tool",
|
||||
toolsMap: map[string]MCPTools{
|
||||
"weather": {
|
||||
Name: "weather",
|
||||
tools: []MCPTools{
|
||||
{
|
||||
ServerName: "weather",
|
||||
Tools: []mcp.Tool{
|
||||
{
|
||||
Name: "get_alerts",
|
||||
@@ -163,9 +163,9 @@ func TestConvertToolsToRecords(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "convert multiple tools",
|
||||
toolsMap: map[string]MCPTools{
|
||||
"ui": {
|
||||
Name: "ui",
|
||||
tools: []MCPTools{
|
||||
{
|
||||
ServerName: "ui",
|
||||
Tools: []mcp.Tool{
|
||||
{
|
||||
Name: "swipe",
|
||||
@@ -205,7 +205,7 @@ func TestConvertToolsToRecords(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ConvertToolsToRecords(tt.toolsMap)
|
||||
got := ConvertToolsToRecords(tt.tools)
|
||||
|
||||
// Compare each record
|
||||
require.Equal(t, len(tt.want), len(got))
|
||||
|
||||
@@ -21,9 +21,9 @@ import (
|
||||
|
||||
// MCPTools represents tools from a single MCP server
|
||||
type MCPTools struct {
|
||||
Name string
|
||||
Tools []mcp.Tool
|
||||
Err error
|
||||
ServerName string
|
||||
Tools []mcp.Tool
|
||||
Err error
|
||||
}
|
||||
|
||||
// MCPHost manages MCP server connections and tools
|
||||
@@ -181,12 +181,12 @@ func (h *MCPHost) connectToServer(ctx context.Context, serverName string, config
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTools fetches available tools from all connected MCP servers
|
||||
func (h *MCPHost) GetTools(ctx context.Context) map[string]MCPTools {
|
||||
// GetTools returns all tools from all MCP servers
|
||||
func (h *MCPHost) GetTools(ctx context.Context) []MCPTools {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
results := make(map[string]MCPTools)
|
||||
var results []MCPTools
|
||||
|
||||
for serverName, conn := range h.connections {
|
||||
if conn.Config.IsDisabled() {
|
||||
@@ -195,19 +195,15 @@ func (h *MCPHost) GetTools(ctx context.Context) map[string]MCPTools {
|
||||
|
||||
listResults, err := conn.Client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
results[serverName] = MCPTools{
|
||||
Name: serverName,
|
||||
Tools: nil,
|
||||
Err: fmt.Errorf("failed to get tools: %w", err),
|
||||
}
|
||||
log.Error().Err(err).Str("server", serverName).Msg("failed to get tools")
|
||||
continue
|
||||
}
|
||||
|
||||
results[serverName] = MCPTools{
|
||||
Name: serverName,
|
||||
Tools: listResults.Tools,
|
||||
Err: nil,
|
||||
}
|
||||
results = append(results, MCPTools{
|
||||
ServerName: serverName,
|
||||
Tools: listResults.Tools,
|
||||
Err: nil,
|
||||
})
|
||||
}
|
||||
|
||||
return results
|
||||
@@ -218,14 +214,28 @@ func (h *MCPHost) GetTool(ctx context.Context, serverName, toolName string) (*mc
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
mcpTools, exists := h.GetTools(ctx)[serverName]
|
||||
if !exists {
|
||||
// Get all tools
|
||||
results := h.GetTools(ctx)
|
||||
|
||||
// Find the server's tools
|
||||
var serverTools MCPTools
|
||||
found := false
|
||||
for _, tools := range results {
|
||||
if tools.ServerName == serverName {
|
||||
serverTools = tools
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
} else if mcpTools.Err != nil {
|
||||
return nil, mcpTools.Err
|
||||
}
|
||||
if serverTools.Err != nil {
|
||||
return nil, serverTools.Err
|
||||
}
|
||||
|
||||
for _, tool := range mcpTools.Tools {
|
||||
// Find the specific tool
|
||||
for _, tool := range serverTools.Tools {
|
||||
if tool.Name == toolName {
|
||||
return &tool, nil
|
||||
}
|
||||
@@ -308,15 +318,14 @@ func (h *MCPHost) CloseServers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEinoTool returns an eino tool from the MCP server
|
||||
// GetEinoTool returns an eino tool for the given server and tool name
|
||||
func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string) (tool.BaseTool, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// filter MCP server by serverName
|
||||
conn, exists := h.connections[serverName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no connection found for server %s", serverName)
|
||||
conn, ok := h.connections[serverName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server not found: %s", serverName)
|
||||
}
|
||||
|
||||
if conn.Config.IsDisabled() {
|
||||
@@ -340,33 +349,33 @@ func (h *MCPHost) GetEinoTool(ctx context.Context, serverName, toolName string)
|
||||
|
||||
// GetEinoToolInfos convert MCP tools to eino tool infos
|
||||
func (h *MCPHost) GetEinoToolInfos(ctx context.Context) ([]*schema.ToolInfo, error) {
|
||||
var allTools []*schema.ToolInfo
|
||||
for serverName, serverTools := range h.GetTools(ctx) {
|
||||
results := h.GetTools(ctx)
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("no MCP servers loaded")
|
||||
}
|
||||
|
||||
var tools []*schema.ToolInfo
|
||||
for _, serverTools := range results {
|
||||
if serverTools.Err != nil {
|
||||
log.Error().
|
||||
Err(serverTools.Err).
|
||||
Str("server", serverName).
|
||||
Msg("Error fetching tools")
|
||||
log.Error().Err(serverTools.Err).Str("server", serverTools.ServerName).Msg("failed to get tools")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tool := range serverTools.Tools {
|
||||
einoTool, err := h.GetEinoTool(ctx, serverName, tool.Name)
|
||||
einoTool, err := h.GetEinoTool(ctx, serverTools.ServerName, tool.Name)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to get eino tool")
|
||||
log.Error().Err(err).Str("server", serverTools.ServerName).Str("tool", tool.Name).Msg("failed to get eino tool")
|
||||
continue
|
||||
}
|
||||
einoToolInfo, err := einoTool.Info(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to get eino tool info")
|
||||
log.Error().Err(err).Str("server", serverTools.ServerName).Str("tool", tool.Name).Msg("failed to get eino tool info")
|
||||
continue
|
||||
}
|
||||
allTools = append(allTools, einoToolInfo)
|
||||
tools = append(tools, einoToolInfo)
|
||||
}
|
||||
log.Info().
|
||||
Str("server", serverName).
|
||||
Int("count", len(serverTools.Tools)).
|
||||
Msg("eino tool infos loaded")
|
||||
}
|
||||
|
||||
return allTools, nil
|
||||
log.Info().Int("count", len(tools)).Msg("eino tool infos loaded")
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
@@ -56,11 +56,16 @@ func TestGetTools(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
assert.Contains(t, tools, "weather")
|
||||
assert.Contains(t, tools, "filesystem")
|
||||
|
||||
// Verify weather tools
|
||||
weatherTools := tools["weather"]
|
||||
var weatherTools MCPTools
|
||||
for _, tool := range tools {
|
||||
if tool.ServerName == "weather" {
|
||||
weatherTools = tool
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, weatherTools.Err)
|
||||
assert.NotEmpty(t, weatherTools.Tools)
|
||||
|
||||
@@ -207,9 +212,18 @@ func TestDisabledServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tools := host.GetTools(ctx)
|
||||
assert.Equal(t, 2, len(tools))
|
||||
assert.Contains(t, tools, "filesystem")
|
||||
assert.Contains(t, tools, "weather")
|
||||
assert.NotContains(t, tools, "disabled_server")
|
||||
|
||||
// Verify enabled servers in tools list
|
||||
var foundFilesystem, foundWeather bool
|
||||
for _, serverTools := range tools {
|
||||
if serverTools.ServerName == "filesystem" {
|
||||
foundFilesystem = true
|
||||
} else if serverTools.ServerName == "weather" {
|
||||
foundWeather = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundFilesystem, "filesystem server not found in tools")
|
||||
assert.True(t, foundWeather, "weather server not found in tools")
|
||||
|
||||
// Test getting tool from disabled server
|
||||
tool, err := host.GetTool(ctx, "disabled_server", "some_tool")
|
||||
|
||||
Reference in New Issue
Block a user