Files
httprunner/uixt/ai/planner.go
lilong.129 ebeae596a7 stash
2025-04-21 14:39:37 +08:00

339 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ai
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
_ "image/jpeg"
"image/png"
"os"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
// Error types
//
// Sentinel errors returned by validateInput; Call wraps them before returning
// them to the caller.
var (
// ErrEmptyInstruction is returned when PlanningOptions.UserInstruction is "".
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
// ErrNoConversationHistory is returned when PlanningOptions.Message is nil.
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
// ErrInvalidImageData is returned when a user message carries an image part
// whose ImageURL is nil.
ErrInvalidImageData = fmt.Errorf("invalid image data")
)
// NewPlanner builds a Planner backed by an OpenAI-compatible chat model.
//
// The model configuration comes from GetModelConfig and the chat model is
// created with the supplied ctx, which is also retained on the Planner and
// reused for every later model call. It returns an error when either the
// configuration or the chat model cannot be created.
func NewPlanner(ctx context.Context) (*Planner, error) {
	cfg, err := GetModelConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
	}

	// named chatModel so the local does not shadow the imported model package
	chatModel, err := openai.NewChatModel(ctx, cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
	}

	planner := &Planner{
		ctx:          ctx,
		config:       cfg,
		model:        chatModel,
		systemPrompt: uiTarsPlanningPrompt,
		// NOTE(review): 1000 is passed straight to NewActionParser; its meaning
		// is defined there — confirm before changing.
		parser: NewActionParser(1000),
	}
	return planner, nil
}
// Planner holds the chat model, its configuration and the running
// conversation used to plan UI actions. Instances are created by NewPlanner.
type Planner struct {
// ctx is the context given to NewPlanner; Call passes it to model.Generate.
ctx context.Context
// model is the chat model that Call invokes with the conversation history.
model model.ChatModel
// config is the model configuration; Call logs its Model field.
config *openai.ChatModelConfig
// systemPrompt is set to uiTarsPlanningPrompt by NewPlanner.
systemPrompt string
// parser turns the model's reply text into actions (see parseResult).
parser *ActionParser
history []*schema.Message // conversation history
}
// Call runs one planning round against the vision language model.
//
// It validates opts, seeds the conversation with a system message (the
// planning prompt followed by the user instruction) when the history is
// empty, records opts.Message in the history, sends the whole history to the
// model and parses the reply into a PlanningResult. On success the result's
// ActionSummary is stored in the history as the assistant turn.
//
// An error is returned when validation fails, when the model request fails,
// or when the reply cannot be parsed.
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
	// reject unusable options before touching the conversation state
	if err := validateInput(opts); err != nil {
		return nil, errors.Wrap(err, "validate input parameters failed")
	}

	// first round only: open the conversation with the system message
	if len(p.history) == 0 {
		p.history = []*schema.Message{{
			Role:    schema.System,
			Content: uiTarsPlanningPrompt + opts.UserInstruction,
		}}
	}

	// record the incoming user (image) message
	p.appendConversationHistory(opts.Message)

	// query the model; the elapsed time is logged whether or not the call failed
	logRequest(p.history)
	begin := time.Now()
	reply, err := p.model.Generate(p.ctx, p.history)
	elapsed := time.Since(begin).Seconds()
	log.Info().Float64("elapsed(s)", elapsed).
		Str("model", p.config.Model).Msg("call model service")
	if err != nil {
		return nil, fmt.Errorf("request model service failed: %w", err)
	}
	logResponse(reply)

	// turn the reply into concrete actions
	plan, err := p.parseResult(reply, opts.Size)
	if err != nil {
		return nil, errors.Wrap(err, "parse result failed")
	}

	// remember the assistant turn as its action summary
	assistantTurn := &schema.Message{
		Role:    schema.Assistant,
		Content: plan.ActionSummary,
	}
	p.appendConversationHistory(assistantTurn)
	return plan, nil
}
// validateInput checks that opts carries everything Call needs.
//
// It returns:
//   - an error when opts itself is nil (previously this panicked),
//   - ErrEmptyInstruction when opts.UserInstruction is empty,
//   - ErrNoConversationHistory when opts.Message is nil,
//   - ErrInvalidImageData when a user message contains an image part whose
//     ImageURL is nil,
//   - nil otherwise.
//
// Image parts are only inspected for messages with the user role.
func validateInput(opts *PlanningOptions) error {
	if opts == nil {
		return fmt.Errorf("planning options is nil")
	}
	if opts.UserInstruction == "" {
		return ErrEmptyInstruction
	}
	if opts.Message == nil {
		return ErrNoConversationHistory
	}
	if opts.Message.Role != schema.User {
		return nil
	}

	// ranging over an empty MultiContent is a no-op, so no length check is needed
	for _, part := range opts.Message.MultiContent {
		if part.Type == schema.ChatMessagePartTypeImageURL && part.ImageURL == nil {
			return ErrInvalidImageData
		}
	}
	return nil
}
// logRequest writes a debug-level log entry describing the messages about to
// be sent to the model.
//
// The messages are copied before logging so the originals are never modified.
// For each message the role is kept; if Content is non-empty it is logged as
// is, otherwise only the image parts of MultiContent are logged (other part
// types are omitted). Inline "data:image/" URLs are replaced with a short
// placeholder so base64 payloads do not end up in the log.
func logRequest(messages []*schema.Message) {
	msgs := make([]*schema.Message, 0, len(messages))
	for _, message := range messages {
		msg := &schema.Message{
			Role: message.Role,
		}
		if message.Content != "" {
			msg.Content = message.Content
		} else {
			for _, mc := range message.MultiContent {
				if mc.Type != schema.ChatMessagePartTypeImageURL {
					continue
				}
				part := schema.ChatMessagePart{Type: mc.Type}
				// A logging helper must not panic: an image part without an
				// ImageURL is logged with its type only.
				if mc.ImageURL != nil {
					// Copy the ImageURL so redaction never touches the real request.
					imageURLCopy := *mc.ImageURL
					if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
						imageURLCopy.URL = "<data:image/base64...>"
					}
					part.ImageURL = &imageURLCopy
				}
				msg.MultiContent = append(msg.MultiContent, part)
			}
		}
		msgs = append(msgs, msg)
	}
	log.Debug().Interface("messages", msgs).Msg("log request messages")
}
// logResponse writes an info-level log entry for the model reply: its role
// and content always, plus ResponseMeta and Extra when they are set.
func logResponse(resp *schema.Message) {
	event := log.Info().
		Str("role", string(resp.Role)).
		Str("content", resp.Content)

	// optional fields are attached only when present
	if meta := resp.ResponseMeta; meta != nil {
		event = event.Interface("response_meta", meta)
	}
	if extra := resp.Extra; extra != nil {
		event = event.Interface("extra", extra)
	}

	event.Msg("log response message")
}
// appendConversationHistory records msg in the conversation history while
// keeping the history bounded.
//
//   - User messages: when the history already holds 4 or more user messages,
//     the oldest one is removed before msg is appended.
//   - Assistant messages: msg is appended, then only the 10 most recent
//     assistant messages are kept; older ones are removed.
//
// Messages with any other role are not stored.
func (p *Planner) appendConversationHistory(msg *schema.Message) {
	switch msg.Role {
	case schema.User:
		// locate the oldest user message and count how many exist
		oldest, total := -1, 0
		for idx, entry := range p.history {
			if entry.Role != schema.User {
				continue
			}
			if oldest < 0 {
				oldest = idx
			}
			total++
		}

		// make room: drop the oldest user message once the limit is reached
		if total >= 4 && oldest >= 0 {
			p.history = append(p.history[:oldest], p.history[oldest+1:]...)
		}
		p.history = append(p.history, msg)

	case schema.Assistant:
		p.history = append(p.history, msg)

		// Walk from newest to oldest; every assistant message past the tenth
		// is removed. Removing index idx only shifts entries that were already
		// visited, so the backward walk stays valid.
		seen := 0
		for idx := len(p.history) - 1; idx >= 0; idx-- {
			if p.history[idx].Role != schema.Assistant {
				continue
			}
			seen++
			if seen > 10 {
				p.history = append(p.history[:idx], p.history[idx+1:]...)
			}
		}
	}
}
// parseResult converts the model reply into a PlanningResult.
//
// The reply text is parsed into actions by the planner's parser, and those
// actions are then handed to processVLMResponse together with size. The
// resulting summary and next actions are logged at info level. An error is
// returned when either step fails.
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
	// step 1: raw text -> parsed actions
	actions, parseErr := p.parser.Parse(msg.Content)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse actions: %w", parseErr)
	}

	// step 2: parsed actions -> planning result
	planning, procErr := processVLMResponse(actions, size)
	if procErr != nil {
		return nil, errors.Wrap(procErr, "process VLM response failed")
	}

	event := log.Info().
		Interface("summary", planning.ActionSummary).
		Interface("actions", planning.NextActions)
	event.Msg("get VLM planning result")
	return planning, nil
}
// SavePositionImg saves an image with a position marker.
//
// params.InputImgBase64 is a base64-encoded image, optionally prefixed with a
// "data:image/...;base64," header. The image is decoded, a filled red disc of
// radius 10 pixels is painted around (params.Rect.X, params.Rect.Y), and the
// result is written to params.OutputPath as a PNG. Marker pixels that fall
// outside the image are skipped.
//
// An error is returned when the base64 text or the image cannot be decoded,
// when the output file cannot be created, or when encoding, writing or
// closing the file fails.
func SavePositionImg(params struct {
	InputImgBase64 string
	Rect           struct {
		X float64
		Y float64
	}
	OutputPath string
}) (err error) {
	const markRadius = 10

	imgData := params.InputImgBase64
	// strip the data-URL header, if any, so only the base64 payload remains
	if strings.HasPrefix(imgData, "data:image/") {
		parts := strings.Split(imgData, ",")
		if len(parts) > 1 {
			imgData = parts[1]
		}
	}

	// base64 text -> raw image bytes
	unbased, err := base64.StdEncoding.DecodeString(imgData)
	if err != nil {
		return fmt.Errorf("无法解码Base64图像: %w", err)
	}

	// raw bytes -> image
	img, _, err := image.Decode(bytes.NewReader(unbased))
	if err != nil {
		return fmt.Errorf("无法解码图像数据: %w", err)
	}

	// copy into an RGBA canvas that can be drawn on
	bounds := img.Bounds()
	rgba := image.NewRGBA(bounds)
	draw.Draw(rgba, bounds, img, bounds.Min, draw.Src)

	// paint the marker at the click/drag position
	cx, cy := int(params.Rect.X), int(params.Rect.Y)
	red := color.RGBA{R: 255, A: 255}
	for dx := -markRadius; dx <= markRadius; dx++ {
		for dy := -markRadius; dy <= markRadius; dy++ {
			if dx*dx+dy*dy > markRadius*markRadius {
				continue
			}
			// test against the real bounds rather than assuming they start at (0,0)
			if pt := image.Pt(cx+dx, cy+dy); pt.In(bounds) {
				rgba.Set(pt.X, pt.Y, red)
			}
		}
	}

	outFile, err := os.Create(params.OutputPath)
	if err != nil {
		return fmt.Errorf("无法创建输出文件: %w", err)
	}
	// Close can report a failed flush of written data; surface it unless an
	// earlier error is already being returned.
	defer func() {
		if closeErr := outFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("无法编码和保存图像: %w", closeErr)
		}
	}()

	// encode as PNG and write
	if err = png.Encode(outFile, rgba); err != nil {
		return fmt.Errorf("无法编码和保存图像: %w", err)
	}
	return nil
}
// loadImage reads the image at imagePath and returns it as a base64 data URL
// together with its pixel dimensions.
//
// The file is decoded (any format registered with the image package) and
// re-encoded as PNG, so the returned string always has the form
// "data:image/png;base64,<payload>". Earlier the media type was taken from
// the source format, which mislabelled the PNG payload for non-PNG inputs.
//
// An error is returned when the file cannot be opened, decoded or re-encoded;
// in that case base64Str is empty and size is the zero value.
func loadImage(imagePath string) (base64Str string, size types.Size, err error) {
	imageFile, err := os.Open(imagePath)
	if err != nil {
		return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err)
	}
	defer imageFile.Close()

	// The source format is intentionally ignored: the payload below is PNG.
	imageData, _, err := image.Decode(imageFile)
	if err != nil {
		return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err)
	}

	bounds := imageData.Bounds()
	size = types.Size{Width: bounds.Dx(), Height: bounds.Dy()}

	var buf bytes.Buffer
	if err := png.Encode(&buf, imageData); err != nil {
		return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
	}

	// the media type must describe the bytes actually carried in the URL
	base64Str = "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
	return base64Str, size, nil
}