mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-12 02:21:29 +08:00
162 lines
4.0 KiB
Go
162 lines
4.0 KiB
Go
package ai
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/base64"
|
||
"fmt"
|
||
"image"
|
||
"image/color"
|
||
"image/draw"
|
||
"image/png"
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/cloudwego/eino/schema"
|
||
"github.com/httprunner/httprunner/v5/code"
|
||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||
"github.com/pkg/errors"
|
||
)
|
||
|
||
type IPlanner interface {
|
||
Call(opts *PlanningOptions) (*PlanningResult, error)
|
||
}
|
||
|
||
// PlanningOptions represents the input options for planning
|
||
type PlanningOptions struct {
|
||
UserInstruction string `json:"user_instruction"` // append to system prompt
|
||
Message *schema.Message `json:"message"`
|
||
Size types.Size `json:"size"`
|
||
}
|
||
|
||
// PlanningResult represents the result of planning
|
||
type PlanningResult struct {
|
||
NextActions []ParsedAction `json:"actions"`
|
||
ActionSummary string `json:"summary"`
|
||
Error string `json:"error,omitempty"`
|
||
}
|
||
|
||
// ParsedAction represents a parsed action from the VLM response
|
||
type ParsedAction struct {
|
||
ActionType ActionType `json:"actionType"`
|
||
ActionInputs map[string]interface{} `json:"actionInputs"`
|
||
Thought string `json:"thought"`
|
||
}
|
||
|
||
// ActionType is the string identifier of a planned action, as it appears in
// ParsedAction.ActionType.
type ActionType string

// Known ActionType values. The string values are part of the JSON contract
// and must not change.
const (
	ActionTypeClick    ActionType = "click"
	ActionTypeTap      ActionType = "tap"
	ActionTypeDrag     ActionType = "drag"
	ActionTypeSwipe    ActionType = "swipe"
	ActionTypeWait     ActionType = "wait"
	ActionTypeFinished ActionType = "finished"
	ActionTypeCallUser ActionType = "call_user"
	ActionTypeType     ActionType = "type"
	ActionTypeScroll   ActionType = "scroll"
)
|
||
|
||
// defaultTimeout is a 30-second default duration. It is not referenced in
// this part of the file; presumably it bounds model requests — confirm at
// the call sites.
const defaultTimeout = 30 * time.Second
|
||
|
||
func validatePlanningInput(opts *PlanningOptions) error {
|
||
if opts.UserInstruction == "" {
|
||
return errors.Wrap(code.LLMPrepareRequestError, "user instruction is empty")
|
||
}
|
||
|
||
if opts.Message == nil || opts.Message.Role == "" {
|
||
return errors.Wrap(code.LLMPrepareRequestError, "user message is empty")
|
||
}
|
||
|
||
if opts.Message.Role == schema.User {
|
||
// check MultiContent
|
||
if len(opts.Message.MultiContent) > 0 {
|
||
for _, content := range opts.Message.MultiContent {
|
||
if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil {
|
||
return errors.Wrap(code.LLMPrepareRequestError, "invalid image data")
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// SavePositionImg decodes a base64-encoded image, draws a filled red circle
// (radius 10 px) centred on params.Rect, and writes the result as a PNG to
// params.OutputPath.
//
// InputImgBase64 may be plain standard base64, or a data URL starting with
// "data:image/" — in that case everything up to and including the first
// comma is stripped. Any image format registered with the image package can
// be decoded (this file imports image/png, which registers PNG).
// Rect.X and Rect.Y are truncated to integer pixel coordinates; the parts
// of the marker that fall outside the image bounds are clipped.
//
// The output file is created or truncated. An error is returned when the
// base64 payload or image data cannot be decoded, or when the output file
// cannot be created, encoded into, or closed.
func SavePositionImg(params struct {
	InputImgBase64 string
	Rect           struct {
		X float64
		Y float64
	}
	OutputPath string
}) (err error) {
	// Strip a data-URL prefix such as "data:image/png;base64," if present.
	imgData := params.InputImgBase64
	if strings.HasPrefix(imgData, "data:image/") {
		if _, payload, found := strings.Cut(imgData, ","); found {
			imgData = payload
		}
	}

	raw, err := base64.StdEncoding.DecodeString(imgData)
	if err != nil {
		return fmt.Errorf("无法解码Base64图像: %w", err)
	}

	img, _, err := image.Decode(bytes.NewReader(raw))
	if err != nil {
		return fmt.Errorf("无法解码图像数据: %w", err)
	}

	// Copy the decoded image onto a mutable RGBA canvas we can draw on.
	bounds := img.Bounds()
	canvas := image.NewRGBA(bounds)
	draw.Draw(canvas, bounds, img, bounds.Min, draw.Src)

	// Mark the click/drag position with a filled red circle.
	const markRadius = 10
	red := color.RGBA{R: 255, A: 255}
	cx, cy := int(params.Rect.X), int(params.Rect.Y)
	for dx := -markRadius; dx <= markRadius; dx++ {
		for dy := -markRadius; dy <= markRadius; dy++ {
			if dx*dx+dy*dy > markRadius*markRadius {
				continue
			}
			// Check against the real bounds (Min may be non-zero) to clip the marker.
			if p := image.Pt(cx+dx, cy+dy); p.In(bounds) {
				canvas.SetRGBA(p.X, p.Y, red)
			}
		}
	}

	outFile, err := os.Create(params.OutputPath)
	if err != nil {
		return fmt.Errorf("无法创建输出文件: %w", err)
	}
	// Close errors on a freshly written file can mean the data was not
	// persisted, so surface them unless an earlier error is already returned.
	defer func() {
		if cerr := outFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("无法关闭输出文件: %w", cerr)
		}
	}()

	if err = png.Encode(outFile, canvas); err != nil {
		return fmt.Errorf("无法编码和保存图像: %w", err)
	}

	return nil
}
|
||
|
||
// maskAPIKey returns a redacted form of key that is safe to log: the first
// and last four bytes are kept around a fixed "******" filler. Keys of
// eight bytes or fewer are replaced entirely by the filler so that nothing
// of a short key is revealed.
func maskAPIKey(key string) string {
	const (
		visible = 4
		filler  = "******"
	)
	if len(key) > 2*visible {
		head, tail := key[:visible], key[len(key)-visible:]
		return head + filler + tail
	}
	return filler
}
|