refactor: mcphost planner

This commit is contained in:
lilong.129
2025-05-18 21:55:01 +08:00
parent e35d644acf
commit 3f1ee03529
13 changed files with 595 additions and 220 deletions

View File

@@ -13,10 +13,7 @@ import (
"github.com/charmbracelet/huh/spinner"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/lipgloss/list"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/pkg/errors"
@@ -25,27 +22,15 @@ import (
)
// NewChat creates a new chat session
func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat, error) {
func (h *MCPHost) NewChat(ctx context.Context) (*Chat, error) {
// Get model config from environment variables
modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeGPT)
if err != nil {
return nil, err
}
model, err := openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
planner, err := ai.NewPlanner(ctx, modelConfig)
if err != nil {
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
}
// Load system prompt from file if provided
systemPrompt := "chat to interact with MCP tools"
if systemPromptFile != "" {
customPrompt, err := loadSystemPrompt(systemPromptFile)
if err != nil {
return nil, errors.Wrap(err, "failed to load system prompt")
}
if customPrompt != "" {
systemPrompt = customPrompt
}
return nil, err
}
// Convert MCP tools to eino tool infos
@@ -53,9 +38,8 @@ func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat,
if err != nil {
return nil, errors.Wrap(err, "failed to get eino tool infos")
}
toolCallingModel, err := model.WithTools(einoTools)
if err != nil {
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
if err := planner.RegisterTools(einoTools); err != nil {
return nil, err
}
// Create markdown renderer
@@ -68,35 +52,21 @@ func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat,
}
return &Chat{
model: toolCallingModel,
systemPrompt: systemPrompt,
history: ai.ConversationHistory{},
renderer: renderer,
host: h,
tools: einoTools,
planner: planner,
renderer: renderer,
host: h,
}, nil
}
// Chat represents a chat session with LLM
type Chat struct {
model model.ToolCallingChatModel
systemPrompt string
history ai.ConversationHistory
renderer *glamour.TermRenderer
host *MCPHost
tools []*schema.ToolInfo
host *MCPHost
planner *ai.Planner
renderer *glamour.TermRenderer
}
// Start starts the chat session
func (c *Chat) Start() error {
// Add system message
c.history = ai.ConversationHistory{
{
Role: schema.System,
Content: c.systemPrompt,
},
}
func (c *Chat) Start(ctx context.Context) error {
c.showWelcome()
for {
@@ -130,54 +100,42 @@ func (c *Chat) Start() error {
}
// run prompt with MCP tools
if err := c.runPrompt(input); err != nil {
if err := c.runPrompt(ctx, input); err != nil {
log.Error().Err(err).Msg("run prompt error")
}
}
}
// runPrompt run prompt with MCP tools
func (c *Chat) runPrompt(prompt string) error {
func (c *Chat) runPrompt(ctx context.Context, prompt string) error {
fmt.Printf("\n%s\n", promptStyle.Render("You: "+prompt))
// Create user message
userMsg := &schema.Message{
Role: schema.User,
Content: prompt,
planningOpts := &ai.PlanningOptions{
UserInstruction: "chat with MCP tools",
Message: &schema.Message{
Role: schema.User,
Content: prompt,
},
}
c.history = append(c.history, userMsg)
// Call LLM model to get response
ctx := context.Background()
var message *schema.Message
var modelErr error
// Call planner to get response
var result *ai.PlanningResult
var err error
_ = spinner.New().Title("Thinking...").Action(func() {
message, modelErr = c.model.Generate(ctx, c.history)
result, err = c.planner.Call(ctx, planningOpts)
}).Run()
if modelErr != nil {
return modelErr
}
// Log usage statistics
if usage := message.ResponseMeta.Usage; usage != nil {
log.Debug().Int("input_tokens", usage.PromptTokens).
Int("output_tokens", usage.CompletionTokens).
Int("total_tokens", usage.TotalTokens).Msg("Usage statistics")
if err != nil {
return err
}
// Handle tool calls
toolCalls := message.ToolCalls
toolCalls := result.ToolCalls
if len(toolCalls) > 0 {
return c.handleToolCalls(ctx, toolCalls)
}
// Add assistant's response to history
toolMsg := &schema.Message{
Role: schema.Assistant,
Content: message.Content,
}
c.history = append(c.history, toolMsg)
c.renderContent("Assistant", message.Content)
c.renderContent("Assistant", result.ActionSummary)
return nil
}
@@ -207,6 +165,12 @@ func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall)
result, err := c.host.InvokeTool(ctx, serverName, toolName, argsMap)
if err != nil {
log.Error().Err(err).Msg("invoke tool failed")
toolMsg := &schema.Message{
Role: schema.Tool,
Content: fmt.Sprintf("invoke tool %s error: %v", serverToolName, err),
ToolCallID: toolCall.ID,
}
c.planner.History().Append(toolMsg)
continue
}
@@ -219,7 +183,7 @@ func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall)
} else {
resultStr = fmt.Sprintf("%+v", result)
}
c.renderContent("Tool result", resultStr)
c.renderContent("Tool Result", resultStr)
// Add tool result to history
toolMsg := &schema.Message{
@@ -227,7 +191,7 @@ func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall)
Content: resultStr,
ToolCallID: toolCall.ID,
}
c.history = append(c.history, toolMsg)
c.planner.History().Append(toolMsg)
}
return nil
}
@@ -242,7 +206,7 @@ func (c *Chat) handleCommand(cmd string) error {
case "/history":
c.showHistory()
case "/clear":
c.clearHistory()
c.planner.History().Clear()
case "/quit":
fmt.Println("Goodbye!")
os.Exit(0)
@@ -272,19 +236,19 @@ You can also press Ctrl+C at any time to quit.
- **system-prompt**: %s
- **mcp-config**: %s
`, c.systemPrompt, c.host.config.ConfigPath)
`, c.planner.SystemPrompt(), c.host.config.ConfigPath)
c.renderContent("", markdown)
}
func (c *Chat) showHistory() {
if len(c.history) <= 1 { // Only system message
if len(*c.planner.History()) <= 1 { // Only system message
fmt.Println("No conversation history yet.")
return
}
fmt.Println("\nConversation History:")
for _, msg := range c.history {
for _, msg := range *c.planner.History() {
if msg.Role == schema.System {
continue
}
@@ -292,18 +256,13 @@ func (c *Chat) showHistory() {
role := "You"
if msg.Role == schema.Assistant {
role = "Assistant"
} else if msg.Role == schema.Tool {
role = "Tool Result"
}
c.renderContent(role, msg.Content)
}
}
func (c *Chat) clearHistory() {
// Keep only the system message
systemMsg := c.history[0]
c.history = ai.ConversationHistory{systemMsg}
fmt.Println("Conversation history cleared.")
}
func (c *Chat) showTools() {
if c.host == nil {
fmt.Println("No MCP host loaded.")
@@ -352,22 +311,6 @@ func (c *Chat) renderContent(title, content string) {
fmt.Printf("\n%s", responseStyle.Render(title+output))
}
// loadSystemPrompt loads the system prompt from a JSON file
func loadSystemPrompt(filePath string) (string, error) {
// Check if file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
}
data, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("error reading prompt file: %v", err)
}
// Read file content directly as prompt
return string(data), nil
}
func getTerminalWidth() int {
width, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {

View File

@@ -2,49 +2,32 @@ package mcphost
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewChat(t *testing.T) {
systemPromptFile := "test_system_prompt.txt"
_ = os.WriteFile(systemPromptFile, []byte("You are a helpful assistant."), 0o600)
defer os.Remove(systemPromptFile)
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), systemPromptFile)
assert.NoError(t, err)
assert.NotNil(t, chat)
assert.NotEmpty(t, chat.systemPrompt)
assert.NotNil(t, chat.tools)
}
func TestRunPromptWithNoToolCall(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), "")
chat, err := host.NewChat(context.Background())
assert.NoError(t, err)
err = chat.runPrompt("hi")
err = chat.runPrompt(context.Background(), "hi")
assert.NoError(t, err)
assert.True(t, len(chat.history) > 1)
assert.True(t, len(*chat.planner.History()) > 1)
}
func TestRunPromptWithToolCall(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), "")
chat, err := host.NewChat(context.Background())
assert.NoError(t, err)
assert.True(t, len(chat.tools) > 0)
err = chat.runPrompt("what is the weather in CA")
err = chat.runPrompt(context.Background(), "what is the weather in CA")
assert.NoError(t, err)
assert.True(t, len(chat.history) > 1)
assert.True(t, len(*chat.planner.History()) > 1)
}