Files
httprunner/uixt/ai/planner.go
2025-04-29 20:08:22 +08:00

162 lines
4.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ai
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
"image/png"
"os"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
)
// IPlanner is the planning abstraction of this package: an implementation
// receives a set of PlanningOptions and produces a PlanningResult.
type IPlanner interface {
// Call runs one planning step for opts and returns the resulting plan,
// or an error if planning could not be completed.
Call(opts *PlanningOptions) (*PlanningResult, error)
}
// PlanningOptions represents the input options for planning.
//
// validatePlanningInput in this file requires UserInstruction to be non-empty
// and Message to be non-nil with a non-empty role.
type PlanningOptions struct {
// UserInstruction is the instruction text supplied by the user.
UserInstruction string `json:"user_instruction"` // append to system prompt
// Message is the chat message handed to the planner; when its role is
// schema.User, any image parts it carries must have a non-nil ImageURL.
Message *schema.Message `json:"message"`
// Size is a types.Size value serialized as "size".
// NOTE(review): presumably the screen/screenshot dimensions — confirm against callers.
Size types.Size `json:"size"`
}
// PlanningResult represents the result of planning.
type PlanningResult struct {
// NextActions is the list of parsed actions, serialized as "actions".
NextActions []ParsedAction `json:"actions"`
// ActionSummary is a textual summary, serialized as "summary".
ActionSummary string `json:"summary"`
// Error is an error description, serialized as "error" and omitted from
// the JSON output when empty.
Error string `json:"error,omitempty"`
}
// ParsedAction represents a parsed action from the VLM response.
type ParsedAction struct {
// ActionType identifies the kind of action (see the ActionType constants).
ActionType ActionType `json:"actionType"`
// ActionInputs holds the action's arguments keyed by name; the value
// types are not constrained by this declaration.
ActionInputs map[string]interface{} `json:"actionInputs"`
// Thought is free-form text attached to the action, serialized as "thought".
// NOTE(review): presumably the model's reasoning for the action — confirm.
Thought string `json:"thought"`
}
// ActionType is the string identifier of an action kind, as stored in
// ParsedAction.ActionType.
type ActionType string

// Known ActionType values. The string on the right-hand side is the exact
// identifier used for each action kind.
const (
ActionTypeClick ActionType = "click"
ActionTypeTap ActionType = "tap"
ActionTypeDrag ActionType = "drag"
ActionTypeSwipe ActionType = "swipe"
ActionTypeWait ActionType = "wait"
ActionTypeFinished ActionType = "finished"
ActionTypeCallUser ActionType = "call_user"
ActionTypeType ActionType = "type"
ActionTypeScroll ActionType = "scroll"
)
const (
// defaultTimeout is a 30-second duration.
// NOTE(review): its consumers are not in this part of the file; presumably
// the default deadline for model requests — confirm.
defaultTimeout = 30 * time.Second
)
// validatePlanningInput checks that opts carries everything a planning request
// needs. It returns nil when opts is acceptable; otherwise it returns
// code.LLMPrepareRequestError wrapped with a message naming the first problem
// found:
//
//   - opts itself is nil,
//   - the user instruction is empty,
//   - the message is nil or has an empty role,
//   - a user-role message contains an image part whose ImageURL is nil.
//
// Messages with a role other than schema.User are not inspected further.
func validatePlanningInput(opts *PlanningOptions) error {
	// Guard against a nil options pointer so callers get a regular error
	// instead of a nil-pointer panic.
	if opts == nil {
		return errors.Wrap(code.LLMPrepareRequestError, "planning options is nil")
	}
	if opts.UserInstruction == "" {
		return errors.Wrap(code.LLMPrepareRequestError, "user instruction is empty")
	}
	if opts.Message == nil || opts.Message.Role == "" {
		return errors.Wrap(code.LLMPrepareRequestError, "user message is empty")
	}
	if opts.Message.Role != schema.User {
		return nil
	}
	// Every image part of a user message must actually carry image data.
	// Ranging over an empty or nil MultiContent is a no-op, so no length
	// check is needed.
	for _, content := range opts.Message.MultiContent {
		if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil {
			return errors.Wrap(code.LLMPrepareRequestError, "invalid image data")
		}
	}
	return nil
}
// SavePositionImg decodes a base64-encoded image, paints a filled red circle
// of radius 10 pixels centred on (params.Rect.X, params.Rect.Y), and writes
// the result as a PNG file to params.OutputPath.
//
// Parameters (fields of params):
//   - InputImgBase64: the image as standard base64, either raw or as a data
//     URL starting with "data:image/"; for a data URL everything up to and
//     including the first comma is dropped before decoding.
//   - Rect.X, Rect.Y: marker centre in image pixel coordinates; the fractional
//     part is truncated. Parts of the marker outside the image are skipped.
//   - OutputPath: destination file; it is created or truncated.
//
// It returns an error if the base64 payload or the image data cannot be
// decoded, if the output file cannot be created, or if encoding, writing or
// closing the file fails. Only image formats registered with the image
// package can be decoded; this file registers PNG through its image/png
// import.
func SavePositionImg(params struct {
	InputImgBase64 string
	Rect           struct {
		X float64
		Y float64
	}
	OutputPath string
}) error {
	// Strip the data-URL header ("data:image/...;base64,") when present.
	// A base64 payload never contains a comma, so cutting at the first comma
	// is sufficient; any stray comma later on surfaces as a decode error
	// instead of silently truncating the payload.
	imgData := params.InputImgBase64
	if strings.HasPrefix(imgData, "data:image/") {
		if i := strings.IndexByte(imgData, ','); i >= 0 {
			imgData = imgData[i+1:]
		}
	}

	// Decode the base64 payload into raw image bytes.
	raw, err := base64.StdEncoding.DecodeString(imgData)
	if err != nil {
		return fmt.Errorf("无法解码Base64图像: %w", err)
	}

	// Decode the image itself.
	img, _, err := image.Decode(bytes.NewReader(raw))
	if err != nil {
		return fmt.Errorf("无法解码图像数据: %w", err)
	}

	// Copy the decoded image into an RGBA canvas that can be drawn on.
	bounds := img.Bounds()
	rgba := image.NewRGBA(bounds)
	draw.Draw(rgba, bounds, img, bounds.Min, draw.Src)

	// Paint a filled red disc at the click/drag position.
	const markRadius = 10
	cx, cy := int(params.Rect.X), int(params.Rect.Y)
	red := color.RGBA{R: 255, A: 255}
	for dy := -markRadius; dy <= markRadius; dy++ {
		for dx := -markRadius; dx <= markRadius; dx++ {
			if dx*dx+dy*dy > markRadius*markRadius {
				continue // outside the circle
			}
			// Clip against the real image rectangle (including its Min
			// corner) rather than assuming the origin is (0, 0).
			if pt := image.Pt(cx+dx, cy+dy); pt.In(bounds) {
				rgba.SetRGBA(pt.X, pt.Y, red)
			}
		}
	}

	// Create the output file and encode the canvas into it as PNG.
	outFile, err := os.Create(params.OutputPath)
	if err != nil {
		return fmt.Errorf("无法创建输出文件: %w", err)
	}
	if err := png.Encode(outFile, rgba); err != nil {
		// Best-effort close: the encode error is the one worth reporting.
		_ = outFile.Close()
		return fmt.Errorf("无法编码和保存图像: %w", err)
	}
	// A failed Close on a written file can mean the data never reached the
	// disk, so it must be reported rather than dropped by a bare defer.
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("无法编码和保存图像: %w", err)
	}
	return nil
}
// maskAPIKey returns a redacted form of key that is safe to show in logs.
// Keys longer than 8 bytes keep their first and last 4 bytes with "******"
// in between; anything shorter (including the empty string) is replaced
// entirely by "******".
func maskAPIKey(key string) string {
	const (
		mask    = "******"
		visible = 4 // bytes kept at each end
	)
	if n := len(key); n > 2*visible {
		return key[:visible] + mask + key[n-visible:]
	}
	return mask
}