mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-28 11:59:41 +08:00
313 lines
7.7 KiB
Go
313 lines
7.7 KiB
Go
package ai
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/base64"
|
||
"fmt"
|
||
"image"
|
||
"image/color"
|
||
"image/draw"
|
||
"image/png"
|
||
"os"
|
||
"strings"
|
||
|
||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||
"github.com/cloudwego/eino/schema"
|
||
"github.com/pkg/errors"
|
||
"github.com/rs/zerolog/log"
|
||
)
|
||
|
||
// Error types

// ErrEmptyInstruction is returned by validateInput when
// PlanningOptions.UserInstruction is the empty string.
var ErrEmptyInstruction = fmt.Errorf("user instruction is empty")

// ErrNoConversationHistory is returned by validateInput when
// PlanningOptions.ConversationHistory contains no messages.
var ErrNoConversationHistory = fmt.Errorf("conversation history is empty")

// ErrInvalidImageData is returned by validateInput when no "user" message in
// the conversation history carries an "image_url" content part.
var ErrInvalidImageData = fmt.Errorf("invalid image data")
|
||
|
||
func NewPlanner(ctx context.Context) (*Planner, error) {
|
||
config, err := GetModelConfig()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
|
||
}
|
||
model, err := openai.NewChatModel(ctx, config)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
|
||
}
|
||
parser := NewActionParser(1000)
|
||
return &Planner{
|
||
ctx: ctx,
|
||
model: model,
|
||
parser: parser,
|
||
}, nil
|
||
}
|
||
|
||
type Planner struct {
|
||
ctx context.Context
|
||
model *openai.ChatModel
|
||
parser *ActionParser
|
||
}
|
||
|
||
// Call performs UI planning using Vision Language Model
|
||
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
|
||
|
||
// validate input parameters
|
||
if err := validateInput(opts); err != nil {
|
||
return nil, errors.Wrap(err, "validate input parameters failed")
|
||
}
|
||
|
||
// call VLM service
|
||
resp, err := p.callVLMService(opts)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, "call VLM service failed")
|
||
}
|
||
|
||
// parse result
|
||
result, err := p.parseResult(resp)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, "parse result failed")
|
||
}
|
||
|
||
log.Info().
|
||
Interface("summary", result.ActionSummary).
|
||
Interface("actions", result.NextActions).
|
||
Msg("get VLM planning result")
|
||
return result, nil
|
||
}
|
||
|
||
func validateInput(opts *PlanningOptions) error {
|
||
if opts.UserInstruction == "" {
|
||
return ErrEmptyInstruction
|
||
}
|
||
|
||
if len(opts.ConversationHistory) == 0 {
|
||
return ErrNoConversationHistory
|
||
}
|
||
|
||
// ensure at least one image URL
|
||
hasImageURL := false
|
||
for _, msg := range opts.ConversationHistory {
|
||
if msg.Role == "user" {
|
||
// check MultiContent
|
||
if len(msg.MultiContent) > 0 {
|
||
for _, content := range msg.MultiContent {
|
||
if content.Type == "image_url" && content.ImageURL != nil {
|
||
hasImageURL = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if hasImageURL {
|
||
break
|
||
}
|
||
}
|
||
|
||
if !hasImageURL {
|
||
return ErrInvalidImageData
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// callVLMService makes the actual call to the VLM service
|
||
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
|
||
log.Info().Msg("calling VLM service...")
|
||
|
||
// prepare prompt
|
||
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
|
||
messages := []*schema.Message{
|
||
{
|
||
Role: schema.System,
|
||
Content: systemPrompt,
|
||
},
|
||
}
|
||
messages = append(messages, opts.ConversationHistory...)
|
||
|
||
// generate response
|
||
resp, err := p.model.Generate(p.ctx, messages)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
|
||
}
|
||
log.Info().Str("content", resp.Content).Msg("get VLM response")
|
||
return resp, nil
|
||
}
|
||
|
||
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
|
||
// parse response
|
||
actions, err := p.parser.Parse(msg.Content)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse actions: %w", err)
|
||
}
|
||
|
||
// process response
|
||
result, err := processVLMResponse(actions)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, "process VLM response failed")
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// processVLMResponse processes the VLM response and converts it to PlanningResult
|
||
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
|
||
log.Info().Msg("processing VLM response...")
|
||
|
||
if len(actions) == 0 {
|
||
return nil, fmt.Errorf("no actions returned from VLM")
|
||
}
|
||
|
||
// validate and post-process each action
|
||
for i := range actions {
|
||
// validate action type
|
||
switch actions[i].ActionType {
|
||
case "click", "left_double", "right_single":
|
||
validateCoordinateAction(&actions[i], "startBox")
|
||
case "drag":
|
||
validateCoordinateAction(&actions[i], "startBox")
|
||
validateCoordinateAction(&actions[i], "endBox")
|
||
case "type":
|
||
validateTypeContent(&actions[i])
|
||
case "wait", "finished", "call_user":
|
||
// these actions do not need extra parameters
|
||
default:
|
||
log.Printf("warning: unknown action type: %s, will try to continue processing", actions[i].ActionType)
|
||
}
|
||
}
|
||
|
||
// extract action summary
|
||
actionSummary := extractActionSummary(actions)
|
||
|
||
return &PlanningResult{
|
||
NextActions: actions,
|
||
ActionSummary: actionSummary,
|
||
}, nil
|
||
}
|
||
|
||
// extractActionSummary extracts the summary from the actions
|
||
func extractActionSummary(actions []ParsedAction) string {
|
||
if len(actions) == 0 {
|
||
return ""
|
||
}
|
||
|
||
// use the Thought of the first action as summary
|
||
if actions[0].Thought != "" {
|
||
return actions[0].Thought
|
||
}
|
||
|
||
// if no Thought, generate summary from action type
|
||
action := actions[0]
|
||
switch action.ActionType {
|
||
case "click":
|
||
return "点击操作"
|
||
case "drag":
|
||
return "拖拽操作"
|
||
case "type":
|
||
content, _ := action.ActionInputs["content"].(string)
|
||
if len(content) > 20 {
|
||
content = content[:20] + "..."
|
||
}
|
||
return fmt.Sprintf("输入文本: %s", content)
|
||
case "wait":
|
||
return "等待操作"
|
||
case "finished":
|
||
return "完成操作"
|
||
case "call_user":
|
||
return "请求用户协助"
|
||
default:
|
||
return fmt.Sprintf("执行 %s 操作", action.ActionType)
|
||
}
|
||
}
|
||
|
||
// validateCoordinateAction validates coordinate-based actions.
//
// It is intended to check the coordinate input named by boxField (callers pass
// "startBox" and, for drags, "endBox").
//
// NOTE(review): currently a no-op placeholder. Neither action nor boxField is
// read, so coordinate actions pass through unchecked.
func validateCoordinateAction(action *ParsedAction, boxField string) {
	// TODO
}
|
||
|
||
// validateTypeContent 验证输入文本内容
|
||
func validateTypeContent(action *ParsedAction) {
|
||
if content, ok := action.ActionInputs["content"]; !ok || content == "" {
|
||
// default to empty string
|
||
action.ActionInputs["content"] = ""
|
||
log.Warn().Msg("type action missing content parameter, set to default")
|
||
}
|
||
}
|
||
|
||
// SavePositionImg decodes a base64-encoded image, paints a red marker at
// params.Rect, and writes the result to params.OutputPath as a PNG.
//
// InputImgBase64 may be raw base64 or a data URL starting with "data:image/".
// For a data URL, everything up to the first comma is stripped. Decoding
// supports only the image formats whose decoders are registered in the binary
// (this file registers PNG through its image/png import).
//
// Rect.X and Rect.Y are pixel coordinates and are truncated to int. The marker
// is a filled disc of radius 10 px, clipped to the image bounds. OutputPath is
// created or truncated.
//
// Errors from decoding, file creation, PNG encoding and closing the output
// file are wrapped with %w.
func SavePositionImg(params struct {
	InputImgBase64 string
	Rect           struct {
		X float64
		Y float64
	}
	OutputPath string
}) error {
	imgData := params.InputImgBase64
	// Strip a data-URL prefix, keeping only the base64 payload.
	if strings.HasPrefix(imgData, "data:image/") {
		if _, payload, found := strings.Cut(imgData, ","); found {
			imgData = payload
		}
	}

	unbased, err := base64.StdEncoding.DecodeString(imgData)
	if err != nil {
		return fmt.Errorf("无法解码Base64图像: %w", err)
	}

	img, _, err := image.Decode(bytes.NewReader(unbased))
	if err != nil {
		return fmt.Errorf("无法解码图像数据: %w", err)
	}

	// Copy into an RGBA canvas that can be drawn on.
	bounds := img.Bounds()
	canvas := image.NewRGBA(bounds)
	draw.Draw(canvas, bounds, img, bounds.Min, draw.Src)

	// Paint a filled red disc centred on the click/drag position.
	const markRadius = 10
	red := color.RGBA{R: 255, A: 255}
	cx, cy := int(params.Rect.X), int(params.Rect.Y)
	for dx := -markRadius; dx <= markRadius; dx++ {
		for dy := -markRadius; dy <= markRadius; dy++ {
			if dx*dx+dy*dy > markRadius*markRadius {
				continue
			}
			// Clip against the image's actual bounds.
			if pt := image.Pt(cx+dx, cy+dy); pt.In(bounds) {
				canvas.SetRGBA(pt.X, pt.Y, red)
			}
		}
	}

	outFile, err := os.Create(params.OutputPath)
	if err != nil {
		return fmt.Errorf("无法创建输出文件: %w", err)
	}
	if err := png.Encode(outFile, canvas); err != nil {
		_ = outFile.Close() // best effort: the encode error is the one worth reporting
		return fmt.Errorf("无法编码和保存图像: %w", err)
	}
	// Close explicitly rather than via defer: a failed Close can mean written
	// data never reached the file, and that must not be reported as success.
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("无法关闭输出文件: %w", err)
	}
	return nil
}
|
||
|
||
// loadImage reads the file at imagePath and returns its contents as a base64
// data URL of the form "data:image/png;base64,<payload>".
//
// The read error from os.ReadFile is returned unchanged.
//
// NOTE(review): the media type is hard-coded to image/png regardless of the
// file's actual format.
func loadImage(imagePath string) (base64Str string, err error) {
	raw, readErr := os.ReadFile(imagePath)
	if readErr != nil {
		return "", readErr
	}

	const dataURLPrefix = "data:image/png;base64,"
	return dataURLPrefix + base64.StdEncoding.EncodeToString(raw), nil
}
|