refactor: mcphost planner

Author: lilong.129
Date: 2025-05-18 21:55:01 +08:00
parent e35d644acf
commit 3f1ee03529
13 changed files with 595 additions and 220 deletions

View File

@@ -27,24 +27,22 @@ var CmdMCPHost = &cobra.Command{
}
// Create chat session
chat, err := host.NewChat(context.Background(), systemPromptFile)
chat, err := host.NewChat(context.Background())
if err != nil {
return fmt.Errorf("failed to create chat session: %w", err)
}
// Start chat
return chat.Start()
return chat.Start(context.Background())
},
}
var (
mcpConfigPath string
dumpPath string
systemPromptFile string
mcpConfigPath string
dumpPath string
)
func init() {
CmdMCPHost.Flags().StringVarP(&mcpConfigPath, "mcp-config", "c", "$HOME/.hrp/mcp.json", "path to the MCP config file")
CmdMCPHost.Flags().StringVar(&dumpPath, "dump", "", "path to save the exported tools JSON file")
CmdMCPHost.Flags().StringVar(&systemPromptFile, "system-prompt", "", "path to system prompt JSON file")
}
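With the system-prompt flag removed, the command body boils down to NewChat(ctx) followed by Start(ctx). Below is a minimal sketch of the equivalent programmatic wiring; the mcphost import path and the config path are assumptions (they are not shown in this diff), and error handling is kept minimal.
package main
import (
	"context"
	"log"
	"github.com/httprunner/httprunner/v5/pkg/mcphost" // assumed import path, not shown in this diff
)
func main() {
	ctx := context.Background()
	// NewMCPHost loads the MCP server definitions from the config file (path is a placeholder).
	host, err := mcphost.NewMCPHost("./mcp.json")
	if err != nil {
		log.Fatal(err)
	}
	// NewChat no longer takes a system-prompt file; the planner owns the prompt.
	chat, err := host.NewChat(ctx)
	if err != nil {
		log.Fatal(err)
	}
	// Start now receives the context explicitly.
	if err := chat.Start(ctx); err != nil {
		log.Fatal(err)
	}
}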

View File

@@ -1 +1 @@
v5.0.0-beta-2505171220
v5.0.0-beta-2505182155

View File

@@ -13,10 +13,7 @@ import (
"github.com/charmbracelet/huh/spinner"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/lipgloss/list"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/pkg/errors"
@@ -25,27 +22,15 @@ import (
)
// NewChat creates a new chat session
func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat, error) {
func (h *MCPHost) NewChat(ctx context.Context) (*Chat, error) {
// Get model config from environment variables
modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeGPT)
if err != nil {
return nil, err
}
model, err := openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
planner, err := ai.NewPlanner(ctx, modelConfig)
if err != nil {
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
}
// Load system prompt from file if provided
systemPrompt := "chat to interact with MCP tools"
if systemPromptFile != "" {
customPrompt, err := loadSystemPrompt(systemPromptFile)
if err != nil {
return nil, errors.Wrap(err, "failed to load system prompt")
}
if customPrompt != "" {
systemPrompt = customPrompt
}
return nil, err
}
// Convert MCP tools to eino tool infos
@@ -53,9 +38,8 @@ func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat,
if err != nil {
return nil, errors.Wrap(err, "failed to get eino tool infos")
}
toolCallingModel, err := model.WithTools(einoTools)
if err != nil {
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
if err := planner.RegisterTools(einoTools); err != nil {
return nil, err
}
// Create markdown renderer
@@ -68,35 +52,21 @@ func (h *MCPHost) NewChat(ctx context.Context, systemPromptFile string) (*Chat,
}
return &Chat{
model: toolCallingModel,
systemPrompt: systemPrompt,
history: ai.ConversationHistory{},
renderer: renderer,
host: h,
tools: einoTools,
planner: planner,
renderer: renderer,
host: h,
}, nil
}
// Chat represents a chat session with LLM
type Chat struct {
model model.ToolCallingChatModel
systemPrompt string
history ai.ConversationHistory
renderer *glamour.TermRenderer
host *MCPHost
tools []*schema.ToolInfo
host *MCPHost
planner *ai.Planner
renderer *glamour.TermRenderer
}
// Start starts the chat session
func (c *Chat) Start() error {
// Add system message
c.history = ai.ConversationHistory{
{
Role: schema.System,
Content: c.systemPrompt,
},
}
func (c *Chat) Start(ctx context.Context) error {
c.showWelcome()
for {
@@ -130,54 +100,42 @@ func (c *Chat) Start() error {
}
// run prompt with MCP tools
if err := c.runPrompt(input); err != nil {
if err := c.runPrompt(ctx, input); err != nil {
log.Error().Err(err).Msg("run prompt error")
}
}
}
// runPrompt run prompt with MCP tools
func (c *Chat) runPrompt(prompt string) error {
func (c *Chat) runPrompt(ctx context.Context, prompt string) error {
fmt.Printf("\n%s\n", promptStyle.Render("You: "+prompt))
// Create user message
userMsg := &schema.Message{
Role: schema.User,
Content: prompt,
planningOpts := &ai.PlanningOptions{
UserInstruction: "chat with MCP tools",
Message: &schema.Message{
Role: schema.User,
Content: prompt,
},
}
c.history = append(c.history, userMsg)
// Call LLM model to get response
ctx := context.Background()
var message *schema.Message
var modelErr error
// Call planner to get response
var result *ai.PlanningResult
var err error
_ = spinner.New().Title("Thinking...").Action(func() {
message, modelErr = c.model.Generate(ctx, c.history)
result, err = c.planner.Call(ctx, planningOpts)
}).Run()
if modelErr != nil {
return modelErr
}
// Log usage statistics
if usage := message.ResponseMeta.Usage; usage != nil {
log.Debug().Int("input_tokens", usage.PromptTokens).
Int("output_tokens", usage.CompletionTokens).
Int("total_tokens", usage.TotalTokens).Msg("Usage statistics")
if err != nil {
return err
}
// Handle tool calls
toolCalls := message.ToolCalls
toolCalls := result.ToolCalls
if len(toolCalls) > 0 {
return c.handleToolCalls(ctx, toolCalls)
}
// Add assistant's response to history
toolMsg := &schema.Message{
Role: schema.Assistant,
Content: message.Content,
}
c.history = append(c.history, toolMsg)
c.renderContent("Assistant", message.Content)
c.renderContent("Assistant", result.ActionSummary)
return nil
}
@@ -207,6 +165,12 @@ func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall)
result, err := c.host.InvokeTool(ctx, serverName, toolName, argsMap)
if err != nil {
log.Error().Err(err).Msg("invoke tool failed")
toolMsg := &schema.Message{
Role: schema.Tool,
Content: fmt.Sprintf("invoke tool %s error: %v", serverToolName, err),
ToolCallID: toolCall.ID,
}
c.planner.History().Append(toolMsg)
continue
}
@@ -219,7 +183,7 @@ func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall)
} else {
resultStr = fmt.Sprintf("%+v", result)
}
c.renderContent("Tool result", resultStr)
c.renderContent("Tool Result", resultStr)
// Add tool result to history
toolMsg := &schema.Message{
@@ -227,7 +191,7 @@ func (c *Chat) handleToolCalls(ctx context.Context, toolCalls []schema.ToolCall)
Content: resultStr,
ToolCallID: toolCall.ID,
}
c.history = append(c.history, toolMsg)
c.planner.History().Append(toolMsg)
}
return nil
}
@@ -242,7 +206,7 @@ func (c *Chat) handleCommand(cmd string) error {
case "/history":
c.showHistory()
case "/clear":
c.clearHistory()
c.planner.History().Clear()
case "/quit":
fmt.Println("Goodbye!")
os.Exit(0)
@@ -272,19 +236,19 @@ You can also press Ctrl+C at any time to quit.
- **system-prompt**: %s
- **mcp-config**: %s
`, c.systemPrompt, c.host.config.ConfigPath)
`, c.planner.SystemPrompt(), c.host.config.ConfigPath)
c.renderContent("", markdown)
}
func (c *Chat) showHistory() {
if len(c.history) <= 1 { // Only system message
if len(*c.planner.History()) <= 1 { // Only system message
fmt.Println("No conversation history yet.")
return
}
fmt.Println("\nConversation History:")
for _, msg := range c.history {
for _, msg := range *c.planner.History() {
if msg.Role == schema.System {
continue
}
@@ -292,18 +256,13 @@ func (c *Chat) showHistory() {
role := "You"
if msg.Role == schema.Assistant {
role = "Assistant"
} else if msg.Role == schema.Tool {
role = "Tool Result"
}
c.renderContent(role, msg.Content)
}
}
func (c *Chat) clearHistory() {
// Keep only the system message
systemMsg := c.history[0]
c.history = ai.ConversationHistory{systemMsg}
fmt.Println("Conversation history cleared.")
}
func (c *Chat) showTools() {
if c.host == nil {
fmt.Println("No MCP host loaded.")
@@ -352,22 +311,6 @@ func (c *Chat) renderContent(title, content string) {
fmt.Printf("\n%s", responseStyle.Render(title+output))
}
// loadSystemPrompt loads the system prompt from a JSON file
func loadSystemPrompt(filePath string) (string, error) {
// Check if file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
}
data, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("error reading prompt file: %v", err)
}
// Read file content directly as prompt
return string(data), nil
}
func getTerminalWidth() int {
width, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {

View File

@@ -2,49 +2,32 @@ package mcphost
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewChat(t *testing.T) {
systemPromptFile := "test_system_prompt.txt"
_ = os.WriteFile(systemPromptFile, []byte("You are a helpful assistant."), 0o600)
defer os.Remove(systemPromptFile)
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), systemPromptFile)
assert.NoError(t, err)
assert.NotNil(t, chat)
assert.NotEmpty(t, chat.systemPrompt)
assert.NotNil(t, chat.tools)
}
func TestRunPromptWithNoToolCall(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), "")
chat, err := host.NewChat(context.Background())
assert.NoError(t, err)
err = chat.runPrompt("hi")
err = chat.runPrompt(context.Background(), "hi")
assert.NoError(t, err)
assert.True(t, len(chat.history) > 1)
assert.True(t, len(*chat.planner.History()) > 1)
}
func TestRunPromptWithToolCall(t *testing.T) {
host, err := NewMCPHost("./testdata/test.mcp.json")
require.NoError(t, err)
chat, err := host.NewChat(context.Background(), "")
chat, err := host.NewChat(context.Background())
assert.NoError(t, err)
assert.True(t, len(chat.tools) > 0)
err = chat.runPrompt("what is the weather in CA")
err = chat.runPrompt(context.Background(), "what is the weather in CA")
assert.NoError(t, err)
assert.True(t, len(chat.history) > 1)
assert.True(t, len(*chat.planner.History()) > 1)
}

View File

@@ -15,8 +15,8 @@ import (
// ILLMService defines the LLM service interface, including planning and assertion capabilities
type ILLMService interface {
Call(opts *PlanningOptions) (*PlanningResult, error)
Assert(opts *AssertOptions) (*AssertionResponse, error)
Call(ctx context.Context, opts *PlanningOptions) (*PlanningResult, error)
Assert(ctx context.Context, opts *AssertOptions) (*AssertionResponse, error)
}
func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) {
@@ -48,13 +48,13 @@ type combinedLLMService struct {
}
// Call performs planning
func (c *combinedLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
return c.planner.Call(opts)
func (c *combinedLLMService) Call(ctx context.Context, opts *PlanningOptions) (*PlanningResult, error) {
return c.planner.Call(ctx, opts)
}
// Assert performs assertion
func (c *combinedLLMService) Assert(opts *AssertOptions) (*AssertionResponse, error) {
return c.asserter.Assert(opts)
func (c *combinedLLMService) Assert(ctx context.Context, opts *AssertOptions) (*AssertionResponse, error) {
return c.asserter.Assert(ctx, opts)
}
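Both planning and assertion now receive a caller-supplied context instead of one captured at construction time. The following is a short sketch of driving the combined service through the updated interface; the helper function is hypothetical, it assumes the required LLM env variables are set, and error handling is abbreviated.
package example
import (
	"context"
	"github.com/httprunner/httprunner/v5/uixt/ai"
	"github.com/httprunner/httprunner/v5/uixt/option"
)
func planThenAssert(planOpts *ai.PlanningOptions, assertOpts *ai.AssertOptions) error {
	ctx := context.Background()
	service, err := ai.NewLLMService(option.LLMServiceTypeGPT)
	if err != nil {
		return err
	}
	// Call and Assert now take the context as their first argument.
	if _, err := service.Call(ctx, planOpts); err != nil {
		return err
	}
	_, err = service.Assert(ctx, assertOpts)
	return err
}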
// LLM model config env variables
@@ -95,12 +95,14 @@ func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
"env %s missed", EnvModelName)
}
temperature := float32(0.01)
maxTokens := 4096
temperature := float32(0.7)
modelConfig := &openai.ChatModelConfig{
BaseURL: openaiBaseURL,
APIKey: openaiAPIKey,
Model: modelName,
Timeout: defaultTimeout,
MaxTokens: &maxTokens,
Temperature: &temperature,
}

View File

@@ -22,7 +22,7 @@ import (
// IAsserter interface defines the contract for assertion operations
type IAsserter interface {
Assert(opts *AssertOptions) (*AssertionResponse, error)
Assert(ctx context.Context, opts *AssertOptions) (*AssertionResponse, error)
}
// AssertOptions represents the input options for assertion
@@ -40,7 +40,6 @@ type AssertionResponse struct {
// Asserter handles assertion using different AI models
type Asserter struct {
ctx context.Context
modelConfig *ModelConfig
model model.ToolCallingChatModel
systemPrompt string
@@ -50,7 +49,6 @@ type Asserter struct {
// NewAsserter creates a new Asserter instance
func NewAsserter(ctx context.Context, modelConfig *ModelConfig) (*Asserter, error) {
asserter := &Asserter{
ctx: ctx,
modelConfig: modelConfig,
systemPrompt: defaultAssertionPrompt,
}
@@ -93,7 +91,7 @@ func NewAsserter(ctx context.Context, modelConfig *ModelConfig) (*Asserter, erro
}
// Assert performs the assertion check on the screenshot
func (a *Asserter) Assert(opts *AssertOptions) (*AssertionResponse, error) {
func (a *Asserter) Assert(ctx context.Context, opts *AssertOptions) (*AssertionResponse, error) {
// Validate input parameters
if err := validateAssertionInput(opts); err != nil {
return nil, errors.Wrap(err, "validate assertion parameters failed")
@@ -136,7 +134,7 @@ Here is the assertion. Please tell whether it is truthy according to the screens
// Call model service, generate response
logRequest(a.history)
startTime := time.Now()
resp, err := a.model.Generate(a.ctx, a.history)
resp, err := a.model.Generate(ctx, a.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", string(a.modelConfig.ModelType)).Msg("call model service for assertion")
if err != nil {
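The ctx parameter threaded through Assert replaces the context previously stored on the Asserter. A compact sketch of the updated call path follows; the model type, assertion text and helper name are illustrative, and error handling is kept minimal.
package example
import (
	"context"
	"github.com/httprunner/httprunner/v5/uixt/ai"
	"github.com/httprunner/httprunner/v5/uixt/option"
	"github.com/httprunner/httprunner/v5/uixt/types"
)
func assertScreen(screenshotBase64 string, size types.Size) (*ai.AssertionResponse, error) {
	ctx := context.Background()
	modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeGPT)
	if err != nil {
		return nil, err
	}
	asserter, err := ai.NewAsserter(ctx, modelConfig)
	if err != nil {
		return nil, err
	}
	// Assert now takes the context as its first argument.
	return asserter.Assert(ctx, &ai.AssertOptions{
		Assertion:  "the login page is displayed",
		Screenshot: screenshotBase64,
		Size:       size,
	})
}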

View File

@@ -54,7 +54,7 @@ func TestValidAssertions(t *testing.T) {
imageBase64, size, err := builtin.LoadImage(tc.imagePath)
require.NoError(t, err)
result, err := asserter.Assert(&AssertOptions{
result, err := asserter.Assert(context.Background(), &AssertOptions{
Assertion: tc.assertion,
Screenshot: imageBase64,
Size: size,
@@ -94,7 +94,7 @@ func TestInvalidParameters(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := asserter.Assert(&AssertOptions{
_, err := asserter.Assert(context.Background(), &AssertOptions{
Assertion: tc.assertion,
Screenshot: tc.screenshot,
Size: tc.size,

View File

@@ -15,7 +15,7 @@ import (
)
type IPlanner interface {
Call(opts *PlanningOptions) (*PlanningResult, error)
Call(ctx context.Context, opts *PlanningOptions) (*PlanningResult, error)
}
// PlanningOptions represents the input options for planning
@@ -27,21 +27,16 @@ type PlanningOptions struct {
// PlanningResult represents the result of planning
type PlanningResult struct {
NextActions []ParsedAction `json:"actions"`
ActionSummary string `json:"summary"`
Error string `json:"error,omitempty"`
ToolCalls []schema.ToolCall `json:"tool_calls"` // TODO: merge to NextActions
NextActions []ParsedAction `json:"actions"`
ActionSummary string `json:"summary"`
Error string `json:"error,omitempty"`
}
func NewPlanner(ctx context.Context, modelConfig *ModelConfig) (*Planner, error) {
planner := &Planner{
ctx: ctx,
modelConfig: modelConfig,
}
if modelConfig.ModelType == option.LLMServiceTypeUITARS {
planner.systemPrompt = uiTarsPlanningPrompt
} else {
planner.systemPrompt = defaultPlanningResponseJsonFormat
parser: NewLLMContentParser(modelConfig.ModelType),
}
var err error
@@ -54,27 +49,51 @@ func NewPlanner(ctx context.Context, modelConfig *ModelConfig) (*Planner, error)
}
type Planner struct {
ctx context.Context
modelConfig *ModelConfig
model model.ToolCallingChatModel
systemPrompt string
history ConversationHistory
modelConfig *ModelConfig
model model.ToolCallingChatModel
parser LLMContentParser
history ConversationHistory
tools []*schema.ToolInfo
}
func (p *Planner) SystemPrompt() string {
return p.parser.SystemPrompt()
}
func (p *Planner) History() *ConversationHistory {
return &p.history
}
func (p *Planner) RegisterTools(tools []*schema.ToolInfo) error {
if p.modelConfig.ModelType == option.LLMServiceTypeUITARS {
// tools have been registered in ui-tars system prompt
return nil
}
// register tools for models with function calling
toolCallingModel, err := p.model.WithTools(tools)
if err != nil {
return errors.Wrap(err, "failed to register tools")
}
p.tools = tools
p.model = toolCallingModel
return nil
}
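Taken together, the planner now encapsulates the model, parser, registered tools and conversation history. The sketch below shows the intended workflow using only names from this diff; the tool list and env setup are assumed to exist, and error handling is trimmed.
package example
import (
	"context"
	"github.com/cloudwego/eino/schema"
	"github.com/httprunner/httprunner/v5/uixt/ai"
	"github.com/httprunner/httprunner/v5/uixt/option"
)
func runPlanner(ctx context.Context, einoTools []*schema.ToolInfo) (*ai.PlanningResult, error) {
	modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeGPT)
	if err != nil {
		return nil, err
	}
	planner, err := ai.NewPlanner(ctx, modelConfig)
	if err != nil {
		return nil, err
	}
	// No-op for UI-TARS models, whose tools live in the system prompt.
	if err := planner.RegisterTools(einoTools); err != nil {
		return nil, err
	}
	// Call threads the context through to the underlying model.
	return planner.Call(ctx, &ai.PlanningOptions{
		UserInstruction: "chat with MCP tools",
		Message: &schema.Message{
			Role:    schema.User,
			Content: "what is the weather in CA",
		},
	})
}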
// Call performs UI planning using Vision Language Model
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
func (p *Planner) Call(ctx context.Context, opts *PlanningOptions) (*PlanningResult, error) {
// validate input parameters
if err := validatePlanningInput(opts); err != nil {
return nil, errors.Wrap(err, "validate planning parameters failed")
}
// prepare prompt
if len(p.history) == 0 {
if len(p.history) == 0 && opts.UserInstruction != "" {
// add system message
p.history = ConversationHistory{
{
Role: schema.System,
Content: p.systemPrompt + opts.UserInstruction,
Content: p.parser.SystemPrompt() + opts.UserInstruction,
},
}
}
@@ -84,50 +103,37 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
// call model service, generate response
logRequest(p.history)
startTime := time.Now()
resp, err := p.model.Generate(p.ctx, p.history)
message, err := p.model.Generate(ctx, p.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", string(p.modelConfig.ModelType)).Msg("call model service")
if err != nil {
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
}
logResponse(resp)
logResponse(message)
// parse result
result, err := p.parseResult(resp, opts.Size)
if err != nil {
return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error())
// handle tool calls
if len(message.ToolCalls) > 0 {
// history will be appended with the tool call execution results
result := &PlanningResult{
ToolCalls: message.ToolCalls,
ActionSummary: message.Content,
}
return result, nil
}
// append assistant message
p.history.Append(&schema.Message{
Role: schema.Assistant,
Content: result.ActionSummary,
})
return result, nil
}
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
var parseActions []ParsedAction
var err error
if p.modelConfig.ModelType == option.LLMServiceTypeUITARS {
// parse Thought/Action format from UI-TARS
parseActions, err = parseThoughtAction(msg.Content)
if err != nil {
return nil, err
}
} else {
// parse JSON format, from VLM like openai/gpt-4o
parseActions, err = parseJSON(msg.Content)
if err != nil {
return nil, err
}
}
// process response
result, err := processVLMResponse(parseActions, size)
// parse message content to actions (tool calls)
result, err := p.parser.Parse(message.Content, opts.Size)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
result = &PlanningResult{
ActionSummary: message.Content,
Error: err.Error(),
}
log.Debug().Str("reason", err.Error()).Msg("parse content to actions failed")
// append assistant message
p.history.Append(&schema.Message{
Role: schema.Assistant,
Content: message.Content,
})
}
log.Info().

View File

@@ -8,11 +8,36 @@ import (
"strings"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
// LLMContentParser parses the content from the LLM response
// each parser corresponds to a model type and its system prompt
type LLMContentParser interface {
SystemPrompt() string
Parse(content string, size types.Size) (*PlanningResult, error)
}
func NewLLMContentParser(modelType option.LLMServiceType) LLMContentParser {
switch modelType {
case option.LLMServiceTypeUITARS:
return &UITARSContentParser{
systemPrompt: uiTarsPlanningPrompt,
}
case option.LLMServiceTypeGPT:
return &JSONContentParser{
systemPrompt: defaultPlanningResponseJsonFormat,
}
default:
return &DefaultContentParser{
systemPrompt: defaultPlanningResponseStringFormat,
}
}
}
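The parser is chosen once per model type and carries the system prompt that matches its expected output format. A brief sketch of how the factory is meant to be used; the helper function is illustrative.
package example
import (
	"github.com/httprunner/httprunner/v5/uixt/ai"
	"github.com/httprunner/httprunner/v5/uixt/option"
	"github.com/httprunner/httprunner/v5/uixt/types"
)
func parsePlanning(content string, size types.Size) (*ai.PlanningResult, error) {
	// GPT-style models answer in the JSON format, so the JSON parser is selected.
	parser := ai.NewLLMContentParser(option.LLMServiceTypeGPT)
	_ = parser.SystemPrompt() // the prompt that matches this parser's expected output
	// size is used to normalize coordinates in the parsed actions
	return parser.Parse(content, size)
}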
// ParsedAction represents a parsed action from the VLM response
type ParsedAction struct {
ActionType ActionType `json:"actionType"`
@@ -34,20 +59,28 @@ const (
ActionTypeScroll ActionType = "scroll"
)
// parseThoughtAction parses the Thought/Action format response
func parseThoughtAction(predictionText string) ([]ParsedAction, error) {
// UITARSContentParser parses the Thought/Action format response
type UITARSContentParser struct {
systemPrompt string
}
func (p *UITARSContentParser) SystemPrompt() string {
return p.systemPrompt
}
func (p *UITARSContentParser) Parse(content string, size types.Size) (*PlanningResult, error) {
thoughtRegex := regexp.MustCompile(`(?is)Thought:(.+?)Action:`)
actionRegex := regexp.MustCompile(`(?is)Action:(.+)`)
// extract Thought part
thoughtMatch := thoughtRegex.FindStringSubmatch(predictionText)
thoughtMatch := thoughtRegex.FindStringSubmatch(content)
var thought string
if len(thoughtMatch) > 1 {
thought = strings.TrimSpace(thoughtMatch[1])
}
// extract Action part, e.g. "click(start_box='(552,454)')"
actionMatch := actionRegex.FindStringSubmatch(predictionText)
actionMatch := actionRegex.FindStringSubmatch(content)
if len(actionMatch) < 2 {
return nil, errors.New("no action found in the response")
}
@@ -55,7 +88,17 @@ func parseThoughtAction(predictionText string) ([]ParsedAction, error) {
actionsText := strings.TrimSpace(actionMatch[1])
// parse action type and parameters
return parseActionText(actionsText, thought)
parseActions, err := parseActionText(actionsText, thought)
if err != nil {
return nil, err
}
// process response
result, err := processVLMResponse(parseActions, size)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
}
return result, nil
}
// parseActionText parses the action text to extract the action type and parameters
@@ -319,17 +362,25 @@ func validateTypeContent(action *ParsedAction) {
}
}
// parseJSON tries to parse the response as JSON format
func parseJSON(predictionText string) ([]ParsedAction, error) {
predictionText = strings.TrimSpace(predictionText)
if strings.HasPrefix(predictionText, "```json") && strings.HasSuffix(predictionText, "```") {
predictionText = strings.TrimPrefix(predictionText, "```json")
predictionText = strings.TrimSuffix(predictionText, "```")
// JSONContentParser parses the response as JSON string format
type JSONContentParser struct {
systemPrompt string
}
func (p *JSONContentParser) SystemPrompt() string {
return p.systemPrompt
}
func (p *JSONContentParser) Parse(content string, size types.Size) (*PlanningResult, error) {
content = strings.TrimSpace(content)
if strings.HasPrefix(content, "```json") && strings.HasSuffix(content, "```") {
content = strings.TrimPrefix(content, "```json")
content = strings.TrimSuffix(content, "```")
}
predictionText = strings.TrimSpace(predictionText)
content = strings.TrimSpace(content)
var response PlanningResult
if err := json.Unmarshal([]byte(predictionText), &response); err != nil {
if err := json.Unmarshal([]byte(content), &response); err != nil {
return nil, fmt.Errorf("failed to parse VLM response: %v", err)
}
@@ -352,7 +403,10 @@ func parseJSON(predictionText string) ([]ParsedAction, error) {
normalizedActions = append(normalizedActions, action)
}
return normalizedActions, nil
return &PlanningResult{
NextActions: normalizedActions,
ActionSummary: response.ActionSummary,
}, nil
}
// normalizeAction normalizes the coordinates in the action
@@ -379,3 +433,50 @@ func normalizeAction(action *ParsedAction) error {
return nil
}
// DefaultContentParser parses the response as string format
type DefaultContentParser struct {
systemPrompt string
}
func (p *DefaultContentParser) SystemPrompt() string {
return p.systemPrompt
}
func (p *DefaultContentParser) Parse(content string, size types.Size) (*PlanningResult, error) {
content = strings.TrimSpace(content)
if strings.HasPrefix(content, "```json") && strings.HasSuffix(content, "```") {
content = strings.TrimPrefix(content, "```json")
content = strings.TrimSuffix(content, "```")
}
content = strings.TrimSpace(content)
var response PlanningResult
if err := json.Unmarshal([]byte(content), &response); err != nil {
return nil, fmt.Errorf("failed to parse VLM response: %v", err)
}
if response.Error != "" {
return nil, errors.New(response.Error)
}
if len(response.NextActions) == 0 {
return nil, errors.New("no actions returned from VLM")
}
// normalize actions
var normalizedActions []ParsedAction
for i := range response.NextActions {
// create a new variable, avoid implicit memory aliasing in for loop.
action := response.NextActions[i]
if err := normalizeAction(&action); err != nil {
return nil, errors.Wrap(err, "failed to normalize action")
}
normalizedActions = append(normalizedActions, action)
}
return &PlanningResult{
NextActions: normalizedActions,
ActionSummary: response.ActionSummary,
}, nil
}

View File

@@ -1,6 +1,21 @@
package ai
import (
"fmt"
"os"
)
// Constants for log fields
const (
vlCoTLog = `"what_the_user_wants_to_do_next_by_instruction": string, // What the user wants to do according to the instruction and previous logs.`
vlCurrentLog = `"log": string, // Log what the next one action (ONLY ONE!) you can do according to the screenshot and the instruction. The typical log looks like "Now I want to use action '{{ action-type }}' to do .. first". If no action should be done, log the reason. Use the same language as the user's instruction.`
llmCurrentLog = `"log": string, // Log what the next actions you can do according to the screenshot and the instruction. The typical log looks like "Now I want to use action '{{ action-type }}' to do ..". If no action should be done, log the reason. Use the same language as the user's instruction.`
commonOutputFields = `"error"?: string, // Error messages about unexpected situations, if any. Only think it is an error when the situation is not expected according to the instruction. Use the same language as the user's instruction.
"more_actions_needed_by_instruction": boolean, // Consider if there is still more action(s) to do after the action in "Log" is done, according to the instruction. If so, set this field to true. Otherwise, set it to false.`
)
// https://www.volcengine.com/docs/82379/1536429
// system prompt for UITARSContentParser
const uiTarsPlanningPrompt = `
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
@@ -28,4 +43,319 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
## User Instruction
`
const defaultPlanningResponseJsonFormat = ``
// system prompt for JSONContentParser
const defaultPlanningResponseJsonFormat = `## Role
You are a versatile professional in software UI automation. Your outstanding contributions will impact the user experience of billions of users.
## Objective
- Decompose the instruction the user asked into a series of actions
- Locate the target element if possible
- If the instruction cannot be accomplished, give a further plan.
## Workflow
1. Receive the screenshot, the element description of the screenshot (if any), the user's instruction and previous logs.
2. Decompose the user's task into a sequence of actions, and place it in the "actions" field. There are different types of actions (Tap / Hover / Input / KeyboardPress / Scroll / FalsyConditionStatement / Sleep).
3. Precisely locate the target element if it's already shown in the screenshot, and put the location info in the "locate" field of the action.
4. If some target element is not shown in the screenshot, consider the user's instruction not feasible on this page and follow the next steps.
5. Consider whether the user's instruction will be accomplished after all the actions
- If yes, set "taskWillBeAccomplished" to true
- If no, don't plan more actions by closing the array. Get ready to reevaluate the task. Some talented people like you will handle this. Give them a clear description of what has been done and what to do next. Put your new plan in the "furtherPlan" field.
## Constraints
- All the actions you composed MUST be based on the page context information you get.
- Trust the "What have been done" field about the task (if any), don't repeat actions in it.
- Respond only with valid JSON. Do not write an introduction or summary or markdown prefix like ` + "```" + `json` + "```" + `.
- If the screenshot and the instruction are totally irrelevant, set reason in the "error" field.
## About the "actions" field
The "locate" param is commonly used in the "param" field of the action, means to locate the target element to perform the action, it conforms to the following scheme:
type LocateParam = {
"id": string, // the id of the element found. It should either be the id marked with a rectangle in the screenshot or the id described in the description.
"prompt"?: string // the description of the element to find. It can only be omitted when locate is null
} | null // If it's not on the page, the LocateParam should be null
## Supported actions
Each action has a "type" and corresponding "param". To be detailed:
- type: 'Tap'
* { locate: {id: string, prompt: string} | null }
- type: 'Hover'
* { locate: {id: string, prompt: string} | null }
- type: 'Input', replace the value in the input field
* { locate: {id: string, prompt: string} | null, param: { value: string } }
* "value" is the final value that should be filled in the input field. No matter what modifications are required, just provide the final value user should see after the action is done.
- type: 'KeyboardPress', press a key
* { param: { value: string } }
- type: 'Scroll', scroll up or down.
* {
locate: {id: string, prompt: string} | null,
param: {
direction: 'down'(default) | 'up' | 'right' | 'left',
scrollType: 'once' (default) | 'untilBottom' | 'untilTop' | 'untilRight' | 'untilLeft',
distance: null | number
}
}
* To scroll some specific element, put the element at the center of the region in the "locate" field. If it's a page scroll, put "null" in the "locate" field.
* "param" is required in this action. If some fields are not specified, use direction "down", "once" scroll type, and "null" distance.
* { param: { button: 'Back' | 'Home' | 'RecentApp' } }
- type: 'ExpectedFalsyCondition'
* { param: { reason: string } }
* use this action when the conditional statement talked about in the instruction is falsy.
- type: 'Sleep'
* { param: { timeMs: number } }
## Output JSON Format:
The JSON format is as follows:
{
"actions": [
// ... some actions
],
"log": "string, // Log what these planned actions do. Do not include further actions that have not been planned",
"error": "string | null, // Error messages about unexpected situations",
"more_actions_needed_by_instruction": "boolean // If all the actions described in the instruction have been covered by this action and logs, set this field to false"
}
## Examples
### Example: Decompose a task
When the instruction is 'Click the language switch button, wait 1s, click "English"', and no log is provided
By viewing the page screenshot and description, you should consider this and output the JSON:
* The main steps should be: tap the switch button, sleep, and tap the 'English' option
* The language switch button is shown in the screenshot, but it's not marked with a rectangle. So we have to use the page description to find the element. By carefully checking the context information (coordinates, attributes, content, etc.), you can find the element.
* The "English" option button is not shown in the screenshot now, it means it may only show after the previous actions are finished. So don't plan any action to do this.
* Log what these actions do: Click the language switch button to open the language options. Wait for 1 second.
* The task cannot be accomplished (because we cannot see the "English" option now), so the "more_actions_needed_by_instruction" field is true.
{
"actions":[
{
"type": "Tap",
"thought": "Click the language switch button to open the language options.",
"param": null,
"locate": { id: "c81c4e9a33", prompt: "The language switch button" },
},
{
"type": "Sleep",
"thought": "Wait for 1 second to ensure the language options are displayed.",
"param": { "timeMs": 1000 },
}
],
"error": null,
"more_actions_needed_by_instruction": true,
"log": "Click the language switch button to open the language options. Wait for 1 second",
}
### Example: What NOT to do
Wrong output:
{
"actions":[
{
"type": "Tap",
"thought": "Click the language switch button to open the language options.",
"param": null,
"locate": {
{ "id": "c81c4e9a33" }, // WRONG: prompt is missing
}
},
{
"type": "Tap",
"thought": "Click the English option",
"param": null,
"locate": null, // This means the 'English' option is not shown in the screenshot, the task cannot be accomplished
}
],
"more_actions_needed_by_instruction": false, // WRONG: should be true
"log": "Click the language switch button to open the language options",
}
Reason:
* The "prompt" is missing in the first 'Locate' action
* Since the option button is not shown in the screenshot, there are still more actions to be done, so the "more_actions_needed_by_instruction" field should be true`
// PlanSchema defines the JSON schema for the plan
type PlanSchema struct {
Type string `json:"type"`
JSONSchema struct {
Name string `json:"name"`
Strict bool `json:"strict"`
Schema struct {
Type string `json:"type"`
Strict bool `json:"strict"`
Properties struct {
Actions struct {
Type string `json:"type"`
Items struct {
Type string `json:"type"`
Strict bool `json:"strict"`
Properties struct {
Thought struct {
Type string `json:"type"`
Description string `json:"description"`
} `json:"thought"`
Type struct {
Type string `json:"type"`
Description string `json:"description"`
} `json:"type"`
Param struct {
AnyOf []struct {
Type string `json:"type,omitempty"`
Properties struct {
Value struct {
Type []string `json:"type"`
} `json:"value,omitempty"`
TimeMs struct {
Type []string `json:"type"`
} `json:"timeMs,omitempty"`
Direction struct {
Type string `json:"type"`
} `json:"direction,omitempty"`
ScrollType struct {
Type string `json:"type"`
} `json:"scrollType,omitempty"`
Distance struct {
Type []string `json:"type"`
} `json:"distance,omitempty"`
Reason struct {
Type string `json:"type"`
} `json:"reason,omitempty"`
Button struct {
Type string `json:"type"`
} `json:"button,omitempty"`
} `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
AdditionalProperties bool `json:"additionalProperties,omitempty"`
} `json:"anyOf"`
Description string `json:"description"`
} `json:"param"`
Locate struct {
Type []string `json:"type"`
Properties struct {
ID struct {
Type string `json:"type"`
} `json:"id"`
Prompt struct {
Type string `json:"type"`
} `json:"prompt"`
} `json:"properties"`
Required []string `json:"required"`
AdditionalProperties bool `json:"additionalProperties"`
Description string `json:"description"`
} `json:"locate"`
} `json:"properties"`
Required []string `json:"required"`
AdditionalProperties bool `json:"additionalProperties"`
} `json:"items"`
Description string `json:"description"`
} `json:"actions"`
MoreActionsNeededByInstruction struct {
Type string `json:"type"`
Description string `json:"description"`
} `json:"more_actions_needed_by_instruction"`
Log struct {
Type string `json:"type"`
Description string `json:"description"`
} `json:"log"`
Error struct {
Type []string `json:"type"`
Description string `json:"description"`
} `json:"error"`
} `json:"properties"`
Required []string `json:"required"`
AdditionalProperties bool `json:"additionalProperties"`
} `json:"schema"`
} `json:"json_schema"`
}
// GetPlanningResponseJsonFormat returns the planning response format based on page type
func GetPlanningResponseJsonFormat(pageType string) string {
if pageType == "android" {
return defaultPlanningResponseJsonFormat + `
- type: 'AndroidBackButton', trigger the system "back" operation on Android devices
* { param: {} }
- type: 'AndroidHomeButton', trigger the system "home" operation on Android devices
* { param: {} }
- type: 'AndroidRecentAppsButton', trigger the system "recent apps" operation on Android devices
* { param: {} }`
}
return defaultPlanningResponseJsonFormat
}
// GenerateTaskBackgroundContext generates the task background context
func GenerateTaskBackgroundContext(userInstruction string, log string, userActionContext string) string {
if log != "" {
return fmt.Sprintf(`
Here is the user's instruction:
<instruction>
<high_priority_knowledge>
%s
</high_priority_knowledge>
%s
</instruction>
These are the logs from previous executions, which indicate what was done in the previous actions.
Do NOT repeat these actions.
<previous_logs>
%s
</previous_logs>
`, userActionContext, userInstruction, log)
}
return fmt.Sprintf(`
Here is the user's instruction:
<instruction>
<high_priority_knowledge>
%s
</high_priority_knowledge>
%s
</instruction>
`, userActionContext, userInstruction)
}
// AutomationUserPrompt generates the automation user prompt
func AutomationUserPrompt(vlMode bool, pageDescription string, taskBackgroundContext string) string {
if vlMode {
return taskBackgroundContext
}
return fmt.Sprintf(`
pageDescription:
=====================================
%s
=====================================
%s`, pageDescription, taskBackgroundContext)
}
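These two helpers compose the user-side prompt: the background block wraps the instruction, high-priority knowledge and previous logs, and the automation prompt prepends the page description unless a vision-language mode is used. A small sketch of chaining them; the argument names are illustrative.
package example
import "github.com/httprunner/httprunner/v5/uixt/ai"
func buildUserPrompt(instruction, previousLogs, knowledge, pageDescription string) string {
	// Wrap the instruction with high-priority knowledge and previous logs.
	background := ai.GenerateTaskBackgroundContext(instruction, previousLogs, knowledge)
	// vlMode=false: include the page description for non-vision models.
	return ai.AutomationUserPrompt(false, pageDescription, background)
}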
const defaultPlanningResponseStringFormat = `
You are a helpful assistant.
`
// loadSystemPrompt loads the system prompt from a file
func loadSystemPrompt(filePath string) (string, error) {
// Check if file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return "", fmt.Errorf("system prompt file does not exist: %s", filePath)
}
data, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("error reading prompt file: %v", err)
}
// Read file content directly as prompt
return string(data), nil
}

View File

@@ -53,7 +53,7 @@ func TestVLMPlanning(t *testing.T) {
}
// Execute planning
result, err := planner.Call(opts)
result, err := planner.Call(context.Background(), opts)
// Validate results
require.NoError(t, err)
@@ -126,7 +126,7 @@ func TestXHSPlanning(t *testing.T) {
}
// Execute planning
result, err := planner.Call(opts)
result, err := planner.Call(context.Background(), opts)
// Validate results
require.NoError(t, err)
@@ -199,7 +199,7 @@ func TestChatList(t *testing.T) {
}
// Execute planning
result, err := planner.Call(opts)
result, err := planner.Call(context.Background(), opts)
// Validate results
require.NoError(t, err)
@@ -246,7 +246,7 @@ func TestHandleSwitch(t *testing.T) {
}
// Execute planning
result, err := planner.Call(opts)
result, err := planner.Call(context.Background(), opts)
// Validate results
require.NoError(t, err)

View File

@@ -44,7 +44,7 @@ func (h *ConversationHistory) Append(msg *schema.Message) {
// for assistant message:
// - keep at most the last 10 assistant messages
if msg.Role == schema.Assistant {
if msg.Role == schema.Assistant || msg.Role == schema.Tool {
// add the new assistant message to the history
*h = append(*h, msg)
@@ -61,6 +61,13 @@ func (h *ConversationHistory) Append(msg *schema.Message) {
}
}
func (h *ConversationHistory) Clear() {
// Keep only the system message
systemMsg := (*h)[0]
*h = ConversationHistory{systemMsg}
log.Info().Msg("conversation history cleared")
}
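With tool messages now retained alongside assistant messages, the history keeps at most the last 10 of them, and Clear resets everything except the leading system message. A short sketch of the intended lifecycle; the message contents are illustrative.
package example
import (
	"github.com/cloudwego/eino/schema"
	"github.com/httprunner/httprunner/v5/uixt/ai"
)
func historyLifecycle() {
	history := ai.ConversationHistory{
		{Role: schema.System, Content: "system prompt"},
	}
	// Tool results are appended the same way assistant replies are.
	history.Append(&schema.Message{
		Role:       schema.Tool,
		Content:    "tool result",
		ToolCallID: "call_1",
	})
	// Clear keeps only the leading system message.
	history.Clear()
}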
func logRequest(messages ConversationHistory) {
msgs := make(ConversationHistory, 0, len(messages))
for _, message := range messages {
@@ -94,7 +101,13 @@ func logResponse(resp *schema.Message) {
logger := log.Info().Str("role", string(resp.Role)).
Str("content", resp.Content)
if resp.ResponseMeta != nil {
logger = logger.Interface("response_meta", resp.ResponseMeta)
logger = logger.Str("finish_reason", resp.ResponseMeta.FinishReason)
// Log usage statistics
if usage := resp.ResponseMeta.Usage; usage != nil {
log.Debug().Int("input_tokens", usage.PromptTokens).
Int("output_tokens", usage.CompletionTokens).
Int("total_tokens", usage.TotalTokens).Msg("usage statistics")
}
}
if resp.Extra != nil {
logger = logger.Interface("extra", resp.Extra)

View File

@@ -1,6 +1,7 @@
package uixt
import (
"context"
"encoding/base64"
"fmt"
"path/filepath"
@@ -102,7 +103,7 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
Size: size,
}
result, err := dExt.LLMService.Call(planningOpts)
result, err := dExt.LLMService.Call(context.Background(), planningOpts)
if err != nil {
return nil, errors.Wrap(err, "failed to get next action from planner")
}
@@ -139,7 +140,7 @@ func (dExt *XTDriver) AIAssert(assertion string, opts ...option.ActionOption) er
Screenshot: screenShotBase64,
Size: size,
}
result, err := dExt.LLMService.Assert(assertOpts)
result, err := dExt.LLMService.Assert(context.Background(), assertOpts)
if err != nil {
return errors.Wrap(err, "AI assertion failed")
}