refactor: llm planner

This commit is contained in:
lilong.129
2025-04-21 20:51:35 +08:00
parent 938a5e9475
commit 70a8ee01f7
10 changed files with 414 additions and 717 deletions

View File

@@ -56,7 +56,7 @@ func WithLLMService(service LLMServiceType) AIServiceOption {
return func(opts *AIServices) {
if service == LLMServiceTypeGPT4o {
var err error
opts.ILLMService, err = NewGPT4oLLMService()
opts.ILLMService, err = NewPlanner(context.Background())
if err != nil {
log.Error().Err(err).Msg("init gpt-4o llm service failed")
os.Exit(code.GetErrorCode(err))
@@ -64,9 +64,10 @@ func WithLLMService(service LLMServiceType) AIServiceOption {
}
if service == LLMServiceTypeUITARS {
var err error
opts.ILLMService, err = NewPlanner(context.Background())
opts.ILLMService, err = NewUITarsPlanner(context.Background())
if err != nil {
log.Error().Err(err).Msg("init ui-tars llm service failed")
os.Exit(code.GetErrorCode(err))
}
}
}

View File

@@ -1,211 +0,0 @@
package ai
import (
"fmt"
"net/http"
"os"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/schema"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/types"
)
type ILLMService interface {
Call(opts *PlanningOptions) (*PlanningResult, error)
}
func NewGPT4oLLMService() (*openaiLLMService, error) {
return &openaiLLMService{}, nil
}
type openaiLLMService struct{}
func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
return nil, nil
}
// PlanningOptions represents the input options for planning
type PlanningOptions struct {
UserInstruction string `json:"user_instruction"` // append to system prompt
Message *schema.Message `json:"message"`
Size types.Size `json:"size"`
}
// PlanningResult represents the result of planning
type PlanningResult struct {
NextActions []ParsedAction `json:"actions"`
ActionSummary string `json:"summary"`
}
// VLMResponse represents the response from the Vision Language Model
type VLMResponse struct {
Actions []ParsedAction `json:"actions"`
Error string `json:"error,omitempty"`
}
// ParsedAction represents a parsed action from the VLM response
type ParsedAction struct {
ActionType ActionType `json:"actionType"`
ActionInputs map[string]interface{} `json:"actionInputs"`
Thought string `json:"thought"`
}
type ActionType string
const (
ActionTypeClick ActionType = "click"
ActionTypeTap ActionType = "tap"
ActionTypeDrag ActionType = "drag"
ActionTypeSwipe ActionType = "swipe"
ActionTypeWait ActionType = "wait"
ActionTypeFinished ActionType = "finished"
ActionTypeCallUser ActionType = "call_user"
ActionTypeType ActionType = "type"
ActionTypeScroll ActionType = "scroll"
)
const (
defaultTimeout = 60 * time.Second
)
type OpenAIInitConfig struct {
ReportURL string `json:"REPORT_SERVER_URL"`
Headers map[string]string `json:"defaultHeaders"`
}
const (
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
EnvOpenAIAPIKey = "OPENAI_API_KEY"
EnvModelName = "LLM_MODEL_NAME"
EnvOpenAIInitConfigJSON = "OPENAI_INIT_CONFIG_JSON"
)
func checkEnvLLM() error {
if err := config.LoadEnv(); err != nil {
return errors.Wrap(code.LoadEnvError, err.Error())
}
openaiBaseURL := os.Getenv("OPENAI_BASE_URL")
if openaiBaseURL == "" {
return errors.Wrap(code.LLMEnvMissedError, "OPENAI_BASE_URL missed")
}
log.Info().Str("OPENAI_BASE_URL", openaiBaseURL).Msg("get env")
openaiAPIKey := os.Getenv("OPENAI_API_KEY")
if openaiAPIKey == "" {
return errors.Wrap(code.LLMEnvMissedError, "OPENAI_API_KEY missed")
}
log.Info().Str("OPENAI_API_KEY", maskAPIKey(openaiAPIKey)).Msg("get env")
modelName := os.Getenv("LLM_MODEL_NAME")
if modelName == "" {
return errors.Wrap(code.LLMEnvMissedError, "LLM_MODEL_NAME missed")
}
log.Info().Str("LLM_MODEL_NAME", modelName).Msg("get env")
return nil
}
// CustomTransport is a custom RoundTripper that adds headers to every request
type CustomTransport struct {
Transport http.RoundTripper
Headers map[string]string
}
// RoundTrip executes a single HTTP transaction and adds custom headers
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, value := range c.Headers {
req.Header.Set(key, value)
}
return c.Transport.RoundTrip(req)
}
type OutputFormat struct {
Thought string `json:"thought"`
Action string `json:"action"`
Error string `json:"error,omitempty"`
}
// GetModelConfig get OpenAI config
func GetModelConfig() (*openai.ChatModelConfig, error) {
if err := checkEnvLLM(); err != nil {
log.Error().Err(err).Msg("check LLM env failed")
return nil, err
}
envConfig := &OpenAIInitConfig{
Headers: make(map[string]string),
}
// read from JSON config first
jsonStr := config.GetEnvConfig(EnvOpenAIInitConfigJSON)
if jsonStr != "" {
if err := json.Unmarshal([]byte(jsonStr), envConfig); err != nil {
return nil, err
}
}
// outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
// if err != nil {
// log.Fatal().Err(err).Msg("NewSchemaRefForValue failed")
// }
modelConfig := &openai.ChatModelConfig{
HTTPClient: &http.Client{
Timeout: defaultTimeout,
Transport: &CustomTransport{
Transport: http.DefaultTransport,
Headers: envConfig.Headers,
},
},
// TODO: set structured response format
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
// ResponseFormat: &openai2.ChatCompletionResponseFormat{
// Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
// JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
// Name: "thought_and_action",
// Description: "data that describes planning thought and action",
// Schema: outputFormatSchema.Value,
// Strict: false,
// },
// },
}
if baseURL := config.GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {
modelConfig.BaseURL = baseURL
} else {
return nil, fmt.Errorf("miss env %s", EnvOpenAIBaseURL)
}
if apiKey := config.GetEnvConfig(EnvOpenAIAPIKey); apiKey != "" {
modelConfig.APIKey = apiKey
} else {
return nil, fmt.Errorf("miss env %s", EnvOpenAIAPIKey)
}
if modelName := config.GetEnvConfig(EnvModelName); modelName != "" {
modelConfig.Model = modelName
} else {
return nil, fmt.Errorf("miss env %s", EnvModelName)
}
// log config info
log.Info().Str("model", modelConfig.Model).
Str("baseURL", modelConfig.BaseURL).
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
}
// maskAPIKey masks the API key
func maskAPIKey(key string) string {
if len(key) <= 8 {
return "******"
}
return key[:4] + "******" + key[len(key)-4:]
}

View File

@@ -1,267 +0,0 @@
package ai
import (
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/pkg/errors"
)
// NewActionParser creates a new ActionParser instance
func NewActionParser(factor float64) *ActionParser {
return &ActionParser{
Factor: factor,
}
}
// ActionParser parses VLM responses and converts them to structured actions
type ActionParser struct {
Factor float64 // TODO
}
// Parse parses the prediction text and extracts actions
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
// try parsing JSON format, from VLM like openai/gpt-4o
jsonActions, jsonErr := p.parseJSON(predictionText)
if jsonErr == nil {
return jsonActions, nil
}
// json parsing failed, try parsing Thought/Action format, from VLM like UI-TARS
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
if thoughtErr == nil {
return thoughtActions, nil
}
return nil, errors.Wrap(thoughtErr, "parse planner response failed")
}
// parseJSON tries to parse the response as JSON format
func (p *ActionParser) parseJSON(predictionText string) ([]ParsedAction, error) {
predictionText = strings.TrimSpace(predictionText)
if strings.HasPrefix(predictionText, "```json") && strings.HasSuffix(predictionText, "```") {
predictionText = strings.TrimPrefix(predictionText, "```json")
predictionText = strings.TrimSuffix(predictionText, "```")
}
predictionText = strings.TrimSpace(predictionText)
var response VLMResponse
if err := json.Unmarshal([]byte(predictionText), &response); err != nil {
return nil, fmt.Errorf("failed to parse VLM response: %v", err)
}
if response.Error != "" {
return nil, errors.New(response.Error)
}
if len(response.Actions) == 0 {
return nil, errors.New("no actions returned from VLM")
}
// normalize actions
var normalizedActions []ParsedAction
for i := range response.Actions {
// create a new variable, avoid implicit memory aliasing in for loop.
action := response.Actions[i]
if err := p.normalizeAction(&action); err != nil {
return nil, errors.Wrap(err, "failed to normalize action")
}
normalizedActions = append(normalizedActions, action)
}
return normalizedActions, nil
}
// parseThoughtAction parses the Thought/Action format response
func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction, error) {
thoughtRegex := regexp.MustCompile(`(?is)Thought:(.+?)Action:`)
actionRegex := regexp.MustCompile(`(?is)Action:(.+)`)
// extract Thought part
thoughtMatch := thoughtRegex.FindStringSubmatch(predictionText)
var thought string
if len(thoughtMatch) > 1 {
thought = strings.TrimSpace(thoughtMatch[1])
}
// extract Action part, e.g. "click(start_box='(552,454)')"
actionMatch := actionRegex.FindStringSubmatch(predictionText)
if len(actionMatch) < 2 {
return nil, errors.New("no action found in the response")
}
actionsText := strings.TrimSpace(actionMatch[1])
// parse action type and parameters
return p.parseActionText(actionsText, thought)
}
// parseActionText parses the action text to extract the action type and parameters
func (p *ActionParser) parseActionText(actionsText, thought string) ([]ParsedAction, error) {
// remove trailing comments
if idx := strings.Index(actionsText, "#"); idx > 0 {
actionsText = strings.TrimSpace(actionsText[:idx])
}
// supported action types and regexes
actionRegexes := map[ActionType]*regexp.Regexp{
"click": regexp.MustCompile(`click\(start_box='([^']+)'\)`),
"left_double": regexp.MustCompile(`left_double\(start_box='([^']+)'\)`),
"right_single": regexp.MustCompile(`right_single\(start_box='([^']+)'\)`),
"drag": regexp.MustCompile(`drag\(start_box='([^']+)', end_box='([^']+)'\)`),
"type": regexp.MustCompile(`type\(content='([^']+)'\)`),
"scroll": regexp.MustCompile(`scroll\(start_box='([^']+)', direction='([^']+)'\)`),
"wait": regexp.MustCompile(`wait\(\)`),
"finished": regexp.MustCompile(`finished\(content='([^']+)'\)`),
"call_user": regexp.MustCompile(`call_user\(\)`),
}
// one or multiple actions, separated by newline
// "click(start_box='<bbox>229 379 229 379</bbox>')
// "click(start_box='<bbox>229 379 229 379</bbox>')\n\nclick(start_box='<bbox>769 519 769 519</bbox>')"
parsedActions := make([]ParsedAction, 0)
for _, actionText := range strings.Split(actionsText, "\n") {
actionText = strings.TrimSpace(actionText)
for actionType, regex := range actionRegexes {
matches := regex.FindStringSubmatch(actionText)
if len(matches) == 0 {
continue
}
var action ParsedAction
action.ActionType = actionType
action.ActionInputs = make(map[string]interface{})
action.Thought = thought
// parse parameters based on action type
switch actionType {
case ActionTypeClick:
if len(matches) > 1 {
coord, err := p.normalizeCoordinates(matches[1])
if err != nil {
return nil, errors.Wrapf(err, "normalize point failed: %s", matches[1])
}
action.ActionInputs["startBox"] = coord
}
case ActionTypeDrag:
if len(matches) > 2 {
// handle start point
startBox, err := p.normalizeCoordinates(matches[1])
if err != nil {
return nil, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
}
action.ActionInputs["startBox"] = startBox
// handle end point
endBox, err := p.normalizeCoordinates(matches[2])
if err != nil {
return nil, errors.Wrapf(err, "normalize endBox failed: %s", matches[2])
}
action.ActionInputs["endBox"] = endBox
}
case ActionTypeType:
if len(matches) > 1 {
action.ActionInputs["content"] = matches[1]
}
case ActionTypeScroll:
if len(matches) > 2 {
startBox, err := p.normalizeCoordinates(matches[1])
if err != nil {
return nil, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
}
action.ActionInputs["startBox"] = startBox
action.ActionInputs["direction"] = matches[2]
}
case ActionTypeWait, ActionTypeFinished, ActionTypeCallUser:
// 这些动作没有额外参数
}
parsedActions = append(parsedActions, action)
}
}
if len(parsedActions) == 0 {
return nil, fmt.Errorf("no valid actions returned from VLM")
}
return parsedActions, nil
}
// normalizeAction normalizes the coordinates in the action
func (p *ActionParser) normalizeAction(action *ParsedAction) error {
switch action.ActionType {
case "click", "drag":
// handle click and drag action coordinates
if startBox, ok := action.ActionInputs["startBox"].(string); ok {
normalized, err := p.normalizeCoordinates(startBox)
if err != nil {
return fmt.Errorf("failed to normalize startBox: %w", err)
}
action.ActionInputs["startBox"] = normalized
}
if endBox, ok := action.ActionInputs["endBox"].(string); ok {
normalized, err := p.normalizeCoordinates(endBox)
if err != nil {
return fmt.Errorf("failed to normalize endBox: %w", err)
}
action.ActionInputs["endBox"] = normalized
}
}
return nil
}
// normalizeCoordinates normalizes the coordinates based on the factor
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
// check empty string
if coordStr == "" {
return nil, fmt.Errorf("empty coordinate string")
}
// handle BBox format: <bbox>x1 y1 x2 y2</bbox>
bboxRegex := regexp.MustCompile(`<bbox>(\d+\s+\d+\s+\d+\s+\d+)</bbox>`)
bboxMatches := bboxRegex.FindStringSubmatch(coordStr)
if len(bboxMatches) > 1 {
// Extract space-separated values from inside the bbox tags
bboxContent := bboxMatches[1]
// Split by whitespace
parts := strings.Fields(bboxContent)
if len(parts) == 4 {
coords = make([]float64, 4)
for i, part := range parts {
val, e := strconv.ParseFloat(part, 64)
if e != nil {
return nil, fmt.Errorf("failed to parse coordinate value '%s': %w", part, e)
}
coords[i] = val
}
// 将 val 转换为 [x,y] 坐标
x := (coords[0] + coords[2]) / 2
y := (coords[1] + coords[3]) / 2
return []float64{x, y}, nil
}
}
// handle coordinate string, e.g. "[100, 200]", "(100, 200)"
if strings.Contains(coordStr, ",") {
// remove possible brackets and split coordinates
coordStr = strings.Trim(coordStr, "[]() \t")
// try parsing JSON array
jsonStr := coordStr
if !strings.HasPrefix(jsonStr, "[") {
jsonStr = "[" + coordStr + "]"
}
err = json.Unmarshal([]byte(jsonStr), &coords)
if err != nil {
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
}
return coords, nil
}
return nil, fmt.Errorf("invalid coordinate string format: %s", coordStr)
}

View File

@@ -2,26 +2,64 @@ package ai
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
_ "image/jpeg"
"image/png"
"os"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
type ILLMService interface {
Call(opts *PlanningOptions) (*PlanningResult, error)
}
// PlanningOptions represents the input options for planning
type PlanningOptions struct {
UserInstruction string `json:"user_instruction"` // append to system prompt
Message *schema.Message `json:"message"`
Size types.Size `json:"size"`
}
// PlanningResult represents the result of planning
type PlanningResult struct {
NextActions []ParsedAction `json:"actions"`
ActionSummary string `json:"summary"`
Error string `json:"error,omitempty"`
}
// ParsedAction represents a parsed action from the VLM response
type ParsedAction struct {
ActionType ActionType `json:"actionType"`
ActionInputs map[string]interface{} `json:"actionInputs"`
Thought string `json:"thought"`
}
type ActionType string
const (
ActionTypeClick ActionType = "click"
ActionTypeTap ActionType = "tap"
ActionTypeDrag ActionType = "drag"
ActionTypeSwipe ActionType = "swipe"
ActionTypeWait ActionType = "wait"
ActionTypeFinished ActionType = "finished"
ActionTypeCallUser ActionType = "call_user"
ActionTypeType ActionType = "type"
ActionTypeScroll ActionType = "scroll"
)
const (
defaultTimeout = 60 * time.Second
)
// Error types
var (
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
@@ -29,81 +67,6 @@ var (
ErrInvalidImageData = fmt.Errorf("invalid image data")
)
func NewPlanner(ctx context.Context) (*Planner, error) {
config, err := GetModelConfig()
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
}
model, err := openai.NewChatModel(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
}
parser := NewActionParser(1000)
return &Planner{
ctx: ctx,
config: config,
model: model,
systemPrompt: uiTarsPlanningPrompt,
parser: parser,
}, nil
}
type Planner struct {
ctx context.Context
model model.ChatModel
config *openai.ChatModelConfig
systemPrompt string
parser *ActionParser
history []*schema.Message // conversation history
}
// Call performs UI planning using Vision Language Model
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
// validate input parameters
if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed")
}
// prepare prompt
if len(p.history) == 0 {
// add system message
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
p.history = []*schema.Message{
{
Role: schema.System,
Content: systemPrompt,
},
}
}
// append user image message
p.appendConversationHistory(opts.Message)
// call model service, generate response
logRequest(p.history)
startTime := time.Now()
resp, err := p.model.Generate(p.ctx, p.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", p.config.Model).Msg("call model service")
if err != nil {
return nil, fmt.Errorf("request model service failed: %w", err)
}
logResponse(resp)
// parse result
result, err := p.parseResult(resp, opts.Size)
if err != nil {
return nil, errors.Wrap(err, "parse result failed")
}
// append assistant message
p.appendConversationHistory(&schema.Message{
Role: schema.Assistant,
Content: result.ActionSummary,
})
return result, nil
}
func validateInput(opts *PlanningOptions) error {
if opts.UserInstruction == "" {
return ErrEmptyInstruction
@@ -169,7 +132,7 @@ func logResponse(resp *schema.Message) {
}
// appendConversationHistory adds a message to the conversation history
func (p *Planner) appendConversationHistory(msg *schema.Message) {
func appendConversationHistory(history []*schema.Message, msg *schema.Message) {
// for user image message:
// - keep at most 4 user image messages
// - delete the oldest user image message when the limit is reached
@@ -179,7 +142,7 @@ func (p *Planner) appendConversationHistory(msg *schema.Message) {
firstUserImgIndex := -1
// calculate the number of user messages and find the index of the first user message
for i, item := range p.history {
for i, item := range history {
if item.Role == schema.User {
userImgCount++
if firstUserImgIndex == -1 {
@@ -191,54 +154,34 @@ func (p *Planner) appendConversationHistory(msg *schema.Message) {
// if there are already 4 user messages, delete the first one before adding the new message
if userImgCount >= 4 && firstUserImgIndex >= 0 {
// delete the first user message
p.history = append(
p.history[:firstUserImgIndex],
p.history[firstUserImgIndex+1:]...,
history = append(
history[:firstUserImgIndex],
history[firstUserImgIndex+1:]...,
)
}
// add the new user message to the history
p.history = append(p.history, msg)
history = append(history, msg)
}
// for assistant message:
// - keep at most the last 10 assistant messages
if msg.Role == schema.Assistant {
// add the new assistant message to the history
p.history = append(p.history, msg)
history = append(history, msg)
// if there are more than 10 assistant messages, remove the oldest ones
assistantMsgCount := 0
for i := len(p.history) - 1; i >= 0; i-- {
if p.history[i].Role == schema.Assistant {
for i := len(history) - 1; i >= 0; i-- {
if history[i].Role == schema.Assistant {
assistantMsgCount++
if assistantMsgCount > 10 {
p.history = append(p.history[:i], p.history[i+1:]...)
history = append(history[:i], history[i+1:]...)
}
}
}
}
}
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
// parse response
parseActions, err := p.parser.Parse(msg.Content)
if err != nil {
return nil, fmt.Errorf("failed to parse actions: %w", err)
}
// process response
result, err := processVLMResponse(parseActions, size)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
}
log.Info().
Interface("summary", result.ActionSummary).
Interface("actions", result.NextActions).
Msg("get VLM planning result")
return result, nil
}
// SavePositionImg saves an image with position markers
func SavePositionImg(params struct {
InputImgBase64 string
@@ -336,3 +279,12 @@ func loadImage(imagePath string) (base64Str string, size types.Size, err error)
return base64Str, size, nil
}
// maskAPIKey masks the API key
func maskAPIKey(key string) string {
if len(key) <= 8 {
return "******"
}
return key[:4] + "******" + key[len(key)-4:]
}

242
uixt/ai/planner_gpt.go Normal file
View File

@@ -0,0 +1,242 @@
package ai
import (
"context"
"fmt"
_ "image/jpeg"
"os"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/getkin/kin-openapi/openapi3gen"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/types"
)
const (
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
EnvOpenAIAPIKey = "OPENAI_API_KEY"
EnvModelName = "LLM_MODEL_NAME"
)
// GetOpenAIModelConfig get OpenAI config
func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
if err := config.LoadEnv(); err != nil {
return nil, errors.Wrap(code.LoadEnvError, err.Error())
}
openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
if openaiBaseURL == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvOpenAIBaseURL)
}
openaiAPIKey := os.Getenv(EnvOpenAIAPIKey)
if openaiAPIKey == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvOpenAIAPIKey)
}
modelName := os.Getenv(EnvModelName)
if modelName == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvModelName)
}
type OutputFormat struct {
Thought string `json:"thought"`
Action string `json:"action"`
Error string `json:"error,omitempty"`
}
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
if err != nil {
return nil, err
}
modelConfig := &openai.ChatModelConfig{
BaseURL: openaiBaseURL,
APIKey: openaiAPIKey,
Model: modelName,
Timeout: defaultTimeout,
// set structured response format
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
ResponseFormat: &openai2.ChatCompletionResponseFormat{
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
Name: "thought_and_action",
Description: "data that describes planning thought and action",
Schema: outputFormatSchema.Value,
Strict: false,
},
},
}
// log config info
log.Info().Str("model", modelConfig.Model).
Str("baseURL", modelConfig.BaseURL).
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
}
func NewPlanner(ctx context.Context) (*Planner, error) {
config, err := GetOpenAIModelConfig()
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
}
model, err := openai.NewChatModel(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
}
return &Planner{
ctx: ctx,
config: config,
model: model,
systemPrompt: uiTarsPlanningPrompt, // TODO: change prompt with function calling
}, nil
}
type Planner struct {
ctx context.Context
model model.ToolCallingChatModel
config *openai.ChatModelConfig
systemPrompt string
history []*schema.Message // conversation history
}
// Call performs UI planning using Vision Language Model
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
// validate input parameters
if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed")
}
// prepare prompt
if len(p.history) == 0 {
// add system message
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
p.history = []*schema.Message{
{
Role: schema.System,
Content: systemPrompt,
},
}
}
// append user image message
appendConversationHistory(p.history, opts.Message)
// call model service, generate response
logRequest(p.history)
startTime := time.Now()
resp, err := p.model.Generate(p.ctx, p.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", p.config.Model).Msg("call model service")
if err != nil {
return nil, fmt.Errorf("request model service failed: %w", err)
}
logResponse(resp)
// parse result
result, err := p.parseResult(resp, opts.Size)
if err != nil {
return nil, errors.Wrap(err, "parse result failed")
}
// append assistant message
appendConversationHistory(p.history, &schema.Message{
Role: schema.Assistant,
Content: result.ActionSummary,
})
return result, nil
}
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
// parse JSON format, from VLM like openai/gpt-4o
parseActions, jsonErr := parseJSON(msg.Content)
if jsonErr != nil {
return nil, jsonErr
}
// process response
result, err := processVLMResponse(parseActions, size)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
}
log.Info().
Interface("summary", result.ActionSummary).
Interface("actions", result.NextActions).
Msg("get VLM planning result")
return result, nil
}
// parseJSON tries to parse the response as JSON format
func parseJSON(predictionText string) ([]ParsedAction, error) {
predictionText = strings.TrimSpace(predictionText)
if strings.HasPrefix(predictionText, "```json") && strings.HasSuffix(predictionText, "```") {
predictionText = strings.TrimPrefix(predictionText, "```json")
predictionText = strings.TrimSuffix(predictionText, "```")
}
predictionText = strings.TrimSpace(predictionText)
var response PlanningResult
if err := json.Unmarshal([]byte(predictionText), &response); err != nil {
return nil, fmt.Errorf("failed to parse VLM response: %v", err)
}
if response.Error != "" {
return nil, errors.New(response.Error)
}
if len(response.NextActions) == 0 {
return nil, errors.New("no actions returned from VLM")
}
// normalize actions
var normalizedActions []ParsedAction
for i := range response.NextActions {
// create a new variable, avoid implicit memory aliasing in for loop.
action := response.NextActions[i]
if err := normalizeAction(&action); err != nil {
return nil, errors.Wrap(err, "failed to normalize action")
}
normalizedActions = append(normalizedActions, action)
}
return normalizedActions, nil
}
// normalizeAction normalizes the coordinates in the action
func normalizeAction(action *ParsedAction) error {
switch action.ActionType {
case "click", "drag":
// handle click and drag action coordinates
if startBox, ok := action.ActionInputs["startBox"].(string); ok {
normalized, err := normalizeCoordinates(startBox)
if err != nil {
return fmt.Errorf("failed to normalize startBox: %w", err)
}
action.ActionInputs["startBox"] = normalized
}
if endBox, ok := action.ActionInputs["endBox"].(string); ok {
normalized, err := normalizeCoordinates(endBox)
if err != nil {
return fmt.Errorf("failed to normalize endBox: %w", err)
}
action.ActionInputs["endBox"] = normalized
}
}
return nil
}

View File

@@ -28,7 +28,7 @@ func TestVLMPlanning(t *testing.T) {
userInstruction += "\n\n请基于以上游戏规则给出下一步可点击的两个图标坐标"
planner, err := NewPlanner(context.Background())
planner, err := NewUITarsPlanner(context.Background())
require.NoError(t, err)
opts := &PlanningOptions{
@@ -98,7 +98,7 @@ func TestXHSPlanning(t *testing.T) {
userInstruction := "点击第二个帖子的作者头像"
planner, err := NewPlanner(context.Background())
planner, err := NewUITarsPlanner(context.Background())
require.NoError(t, err)
opts := &PlanningOptions{
@@ -168,7 +168,7 @@ func TestChatList(t *testing.T) {
userInstruction := "请结合图片的文字信息,请告诉我一共有多少个群聊,哪些群聊右下角有绿点"
planner, err := NewPlanner(context.Background())
planner, err := NewUITarsPlanner(context.Background())
require.NoError(t, err)
opts := &PlanningOptions{

View File

@@ -13,17 +13,52 @@ import (
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
const (
EnvArkBaseURL = "ARK_BASE_URL"
EnvArkAPIKey = "ARK_API_KEY"
EnvArkModelID = "ARK_MODEL_ID"
)
func GetArkModelConfig() (*ark.ChatModelConfig, error) {
return &ark.ChatModelConfig{
APIKey: os.Getenv("ARK_API_KEY"),
Model: os.Getenv("ARK_MODEL_ID"),
}, nil
if err := config.LoadEnv(); err != nil {
return nil, errors.Wrap(code.LoadEnvError, err.Error())
}
arkBaseURL := os.Getenv(EnvArkBaseURL)
arkAPIKey := os.Getenv(EnvArkAPIKey)
if arkAPIKey == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvArkAPIKey)
}
modelName := os.Getenv(EnvArkModelID)
if modelName == "" {
return nil, errors.Wrapf(code.LLMEnvMissedError,
"env %s missed", EnvArkModelID)
}
timeout := defaultTimeout
modelConfig := &ark.ChatModelConfig{
BaseURL: arkBaseURL,
APIKey: arkAPIKey,
Model: modelName,
Timeout: &timeout,
}
// log config info
log.Info().Str("model", modelConfig.Model).
Str("baseURL", modelConfig.BaseURL).
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
}
func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) {
@@ -113,7 +148,7 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) {
logResponse(resp)
// parse result
result, err := parseResult(resp, opts.Size)
result, err := p.parseResult(resp, opts.Size)
if err != nil {
return nil, errors.Wrap(err, "parse result failed")
}
@@ -127,58 +162,7 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) {
return result, nil
}
// appendConversationHistory adds a message to the conversation history
func appendConversationHistory(history []*schema.Message, msg *schema.Message) {
// for user image message:
// - keep at most 4 user image messages
// - delete the oldest user image message when the limit is reached
if msg.Role == schema.User {
// get all existing user messages
userImgCount := 0
firstUserImgIndex := -1
// calculate the number of user messages and find the index of the first user message
for i, item := range history {
if item.Role == schema.User {
userImgCount++
if firstUserImgIndex == -1 {
firstUserImgIndex = i
}
}
}
// if there are already 4 user messages, delete the first one before adding the new message
if userImgCount >= 4 && firstUserImgIndex >= 0 {
// delete the first user message
history = append(
history[:firstUserImgIndex],
history[firstUserImgIndex+1:]...,
)
}
// add the new user message to the history
history = append(history, msg)
}
// for assistant message:
// - keep at most the last 10 assistant messages
if msg.Role == schema.Assistant {
// add the new assistant message to the history
history = append(history, msg)
// if there are more than 10 assistant messages, remove the oldest ones
assistantMsgCount := 0
for i := len(history) - 1; i >= 0; i-- {
if history[i].Role == schema.Assistant {
assistantMsgCount++
if assistantMsgCount > 10 {
history = append(history[:i], history[i+1:]...)
}
}
}
}
}
func parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
func (p *UITarsPlanner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
// parse Thought/Action format from UI-TARS
parseActions, thoughtErr := parseThoughtAction(msg.Content)
if thoughtErr != nil {