mirror of
https://github.com/httprunner/httprunner.git
synced 2026-06-05 15:59:33 +08:00
feat: add AIAsert
This commit is contained in:
@@ -52,23 +52,76 @@ const (
|
||||
LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3"
|
||||
)
|
||||
|
||||
// ILLMService 定义了 LLM 服务接口,包括规划和断言功能
|
||||
type ILLMService interface {
|
||||
Call(opts *PlanningOptions) (*PlanningResult, error)
|
||||
Assert(opts *AssertOptions) (*AssertionResponse, error)
|
||||
}
|
||||
|
||||
func WithLLMService(service LLMServiceType) AIServiceOption {
|
||||
return func(opts *AIServices) {
|
||||
if service == LLMServiceTypeGPT4o {
|
||||
var err error
|
||||
opts.ILLMService, err = NewPlanner(context.Background())
|
||||
switch service {
|
||||
case LLMServiceTypeGPT4o:
|
||||
planner, err := NewPlanner(context.Background())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("init gpt-4o llm service failed")
|
||||
log.Error().Err(err).Msg("init gpt-4o planner failed")
|
||||
os.Exit(code.GetErrorCode(err))
|
||||
}
|
||||
}
|
||||
if service == LLMServiceTypeUITARS {
|
||||
var err error
|
||||
opts.ILLMService, err = NewUITarsPlanner(context.Background())
|
||||
|
||||
asserter, err := NewUITarsAsserter(context.Background())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("init ui-tars llm service failed")
|
||||
log.Error().Err(err).Msg("init ui-tars asserter failed")
|
||||
os.Exit(code.GetErrorCode(err))
|
||||
}
|
||||
|
||||
opts.ILLMService = &combinedLLMService{
|
||||
planner: planner,
|
||||
asserter: asserter,
|
||||
}
|
||||
|
||||
case LLMServiceTypeUITARS:
|
||||
planner, err := NewUITarsPlanner(context.Background())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("init ui-tars planner failed")
|
||||
os.Exit(code.GetErrorCode(err))
|
||||
}
|
||||
|
||||
asserter, err := NewUITarsAsserter(context.Background())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("init ui-tars asserter failed")
|
||||
os.Exit(code.GetErrorCode(err))
|
||||
}
|
||||
|
||||
opts.ILLMService = &combinedLLMService{
|
||||
planner: planner,
|
||||
asserter: asserter,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// combinedLLMService 实现了 ILLMService 接口,组合了规划和断言功能
|
||||
type combinedLLMService struct {
|
||||
planner IPlanner // 提供规划功能
|
||||
asserter IAsserter // 提供断言功能
|
||||
}
|
||||
|
||||
// IPlanner 定义了规划功能接口
|
||||
type IPlanner interface {
|
||||
Call(opts *PlanningOptions) (*PlanningResult, error)
|
||||
}
|
||||
|
||||
// IAsserter 定义了断言功能接口
|
||||
type IAsserter interface {
|
||||
Assert(opts *AssertOptions) (*AssertionResponse, error)
|
||||
}
|
||||
|
||||
// Call 执行规划功能
|
||||
func (c *combinedLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
return c.planner.Call(opts)
|
||||
}
|
||||
|
||||
// Assert 执行断言功能
|
||||
func (c *combinedLLMService) Assert(opts *AssertOptions) (*AssertionResponse, error) {
|
||||
return c.asserter.Assert(opts)
|
||||
}
|
||||
|
||||
104
uixt/ai/asserter_test.go
Normal file
104
uixt/ai/asserter_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 创建AI服务的辅助函数
|
||||
func createAIService(t *testing.T) *AIServices {
|
||||
aiService := NewAIService(WithLLMService(LLMServiceTypeUITARS))
|
||||
require.NotNil(t, aiService)
|
||||
require.NotNil(t, aiService.ILLMService)
|
||||
return aiService
|
||||
}
|
||||
|
||||
// 测试有效断言
|
||||
func TestValidAssertions(t *testing.T) {
|
||||
aiService := createAIService(t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
assertion string
|
||||
imagePath string
|
||||
expectPass bool
|
||||
}{
|
||||
{
|
||||
name: "深度思考功能已开启",
|
||||
assertion: "深度思考开关已开启(蓝色为开启状态,灰色为关闭状态)",
|
||||
imagePath: "testdata/deepseek_think_on.png",
|
||||
expectPass: true,
|
||||
},
|
||||
{
|
||||
name: "深度思考功能未开启",
|
||||
assertion: "深度思考开关未开启(蓝色为开启状态,灰色为关闭状态)",
|
||||
imagePath: "testdata/deepseek_think_off.png",
|
||||
expectPass: true,
|
||||
},
|
||||
{
|
||||
name: "网络连接功能已开启",
|
||||
assertion: "网络连接开关已开启(蓝色为开启状态,灰色为关闭状态)",
|
||||
imagePath: "testdata/deepseek_network_on.png",
|
||||
expectPass: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
imageBase64, size, err := loadImage(tc.imagePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := aiService.ILLMService.Assert(&AssertOptions{
|
||||
Assertion: tc.assertion,
|
||||
Screenshot: imageBase64,
|
||||
Size: size,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tc.expectPass, result.Pass)
|
||||
assert.NotEmpty(t, result.Thought)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 测试无效参数
|
||||
func TestInvalidParameters(t *testing.T) {
|
||||
aiService := createAIService(t)
|
||||
testCases := []struct {
|
||||
name string
|
||||
assertion string
|
||||
screenshot string
|
||||
size types.Size
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "缺少截图",
|
||||
assertion: "测试断言",
|
||||
screenshot: "",
|
||||
size: types.Size{},
|
||||
expectedError: "screenshot is required",
|
||||
},
|
||||
{
|
||||
name: "缺少断言",
|
||||
assertion: "",
|
||||
screenshot: "some-base64-data",
|
||||
size: types.Size{},
|
||||
expectedError: "assertion text is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := aiService.ILLMService.Assert(&AssertOptions{
|
||||
Assertion: tc.assertion,
|
||||
Screenshot: tc.screenshot,
|
||||
Size: tc.size,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedError)
|
||||
})
|
||||
}
|
||||
}
|
||||
257
uixt/ai/asserter_ui_tars.go
Normal file
257
uixt/ai/asserter_ui_tars.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/internal/json"
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// UI-TARS assertion system prompt
|
||||
const uiTarsAssertionPrompt = `You are a senior testing engineer. User will give an assertion and a screenshot of a page. By carefully viewing the screenshot, please tell whether the assertion is truthy.
|
||||
|
||||
## Output Json String Format
|
||||
` + "```" + `
|
||||
"{
|
||||
"pass": <<is a boolean value from the enum [true, false], true means the assertion is truthy>>,
|
||||
"thought": "<<is a string, give the reason why the assertion is falsy or truthy. Otherwise.>>"
|
||||
}"
|
||||
` + "```" + `
|
||||
|
||||
## Rules **MUST** follow
|
||||
- Make sure to return **only** the JSON, with **no additional** text or explanations.
|
||||
- Use Chinese in 'thought' part.
|
||||
- You **MUST** strictly follow up the **Output Json String Format**.`
|
||||
|
||||
// AssertionResponse represents the response from an AI assertion
|
||||
type AssertionResponse struct {
|
||||
Pass bool `json:"pass"`
|
||||
Thought string `json:"thought"`
|
||||
}
|
||||
|
||||
// UITarsAsserter handles assertion using UI-TARS VLM
|
||||
type UITarsAsserter struct {
|
||||
ctx context.Context
|
||||
model *ark.ChatModel
|
||||
config *ark.ChatModelConfig
|
||||
systemPrompt string
|
||||
history []*schema.Message // conversation history
|
||||
}
|
||||
|
||||
// NewUITarsAsserter creates a new UITarsAsserter instance
|
||||
func NewUITarsAsserter(ctx context.Context) (*UITarsAsserter, error) {
|
||||
config, err := GetArkModelConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chatModel, err := ark.NewChatModel(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UITarsAsserter{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
model: chatModel,
|
||||
systemPrompt: uiTarsAssertionPrompt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AssertOptions represents the input options for assertion
|
||||
type AssertOptions struct {
|
||||
Assertion string `json:"assertion"` // The assertion text to verify
|
||||
Screenshot string `json:"screenshot"` // Base64 encoded screenshot
|
||||
Size types.Size `json:"size"` // Screen dimensions
|
||||
}
|
||||
|
||||
// Assert performs the assertion check on the screenshot
|
||||
func (a *UITarsAsserter) Assert(opts *AssertOptions) (*AssertionResponse, error) {
|
||||
// Validate input parameters
|
||||
if opts.Assertion == "" {
|
||||
return nil, errors.New("assertion text is required")
|
||||
}
|
||||
if opts.Screenshot == "" {
|
||||
return nil, errors.New("screenshot is required")
|
||||
}
|
||||
|
||||
// Reset history for each new assertion
|
||||
a.history = []*schema.Message{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: a.systemPrompt,
|
||||
},
|
||||
}
|
||||
|
||||
// Create user message with screenshot and assertion
|
||||
userMsg := &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: opts.Screenshot,
|
||||
Detail: schema.ImageURLDetailAuto,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: fmt.Sprintf(`
|
||||
Here is the assertion. Please tell whether it is truthy according to the screenshot.
|
||||
=====================================
|
||||
%s
|
||||
=====================================
|
||||
`, opts.Assertion),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Append user message to history
|
||||
appendConversationHistory(&a.history, userMsg)
|
||||
|
||||
// Call model service, generate response
|
||||
logRequest(a.history)
|
||||
startTime := time.Now()
|
||||
resp, err := a.model.Generate(a.ctx, a.history)
|
||||
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
||||
Str("model", a.config.Model).Msg("call model service for assertion")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request model service failed: %w", err)
|
||||
}
|
||||
logResponse(resp)
|
||||
|
||||
// Parse result
|
||||
result, err := parseAssertionResult(resp.Content)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parse assertion result failed")
|
||||
}
|
||||
|
||||
// Append assistant message to history
|
||||
appendConversationHistory(&a.history, &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: resp.Content,
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseAssertionResult 解析模型返回的JSON响应
|
||||
func parseAssertionResult(content string) (*AssertionResponse, error) {
|
||||
// 1. 从响应中提取JSON内容
|
||||
jsonContent := extractJSON(content)
|
||||
if jsonContent == "" {
|
||||
return nil, errors.New("could not extract JSON from response")
|
||||
}
|
||||
|
||||
// 2. 预处理和标准解析尝试
|
||||
jsonContent = prepareJSON(jsonContent)
|
||||
var result AssertionResponse
|
||||
if err := json.Unmarshal([]byte(jsonContent), &result); err == nil {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// 3. 备用:正则表达式解析
|
||||
if pass, thought := extractWithRegex(jsonContent); thought != "" {
|
||||
return &AssertionResponse{Pass: pass, Thought: thought}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to parse assertion result")
|
||||
}
|
||||
|
||||
// prepareJSON 预处理JSON字符串,修复常见问题
|
||||
func prepareJSON(jsonStr string) string {
|
||||
// 1. 去除可能的外层引号
|
||||
jsonStr = strings.TrimSpace(jsonStr)
|
||||
if strings.HasPrefix(jsonStr, "\"") && strings.HasSuffix(jsonStr, "\"") {
|
||||
jsonStr = jsonStr[1 : len(jsonStr)-1]
|
||||
}
|
||||
|
||||
// 2. 转义thought内容中的引号
|
||||
thoughtRegex := regexp.MustCompile(`"thought":\s*"([^"]*)"`)
|
||||
matches := thoughtRegex.FindStringSubmatch(jsonStr)
|
||||
if len(matches) > 1 {
|
||||
thoughtValue := matches[1]
|
||||
fixedThought := strings.ReplaceAll(thoughtValue, "\"", "\\\"")
|
||||
jsonStr = strings.Replace(jsonStr, matches[0], fmt.Sprintf(`"thought": "%s"`, fixedThought), 1)
|
||||
}
|
||||
|
||||
// 3. 处理换行和特殊字符
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\n", "\\n")
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\r", "\\r")
|
||||
jsonStr = strings.ReplaceAll(jsonStr, "\t", "\\t")
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// extractWithRegex 使用正则表达式提取pass和thought值
|
||||
func extractWithRegex(jsonStr string) (pass bool, thought string) {
|
||||
// 提取pass值
|
||||
passRegex := regexp.MustCompile(`"pass":\s*(true|false)`)
|
||||
passMatches := passRegex.FindStringSubmatch(jsonStr)
|
||||
|
||||
// 提取thought值
|
||||
thoughtRegex := regexp.MustCompile(`"thought":\s*"([^"]*(?:"[^"]*)*)"`)
|
||||
thoughtMatches := thoughtRegex.FindStringSubmatch(jsonStr)
|
||||
|
||||
if len(passMatches) > 1 && len(thoughtMatches) > 1 {
|
||||
// 处理提取的值
|
||||
pass = passMatches[1] == "true"
|
||||
thought = strings.ReplaceAll(thoughtMatches[1], "\\\"", "\"")
|
||||
thought = strings.ReplaceAll(thought, "\\\\", "\\")
|
||||
return pass, thought
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// extractJSON extracts JSON content from a string that might contain markdown or other formatting
|
||||
func extractJSON(content string) string {
|
||||
// Try to extract JSON directly
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// If the content is already a valid JSON, return it
|
||||
if strings.HasPrefix(content, "{") && strings.HasSuffix(content, "}") {
|
||||
return content
|
||||
}
|
||||
|
||||
// Check for markdown code blocks with more flexible pattern
|
||||
jsonRegex := regexp.MustCompile(`(?:json)?\s*({[\s\S]*?})\s*`)
|
||||
matches := jsonRegex.FindStringSubmatch(content)
|
||||
if len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
// Try a more robust approach for JSON with Chinese characters
|
||||
// First look for the outermost pair of curly braces
|
||||
startIdx := strings.Index(content, "{")
|
||||
if startIdx >= 0 {
|
||||
depth := 1
|
||||
for i := startIdx + 1; i < len(content); i++ {
|
||||
if content[i] == '{' {
|
||||
depth++
|
||||
} else if content[i] == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
// Found the closing brace
|
||||
return content[startIdx : i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to regex approach
|
||||
braceRegex := regexp.MustCompile(`{[\s\S]*?}`)
|
||||
matches = braceRegex.FindStringSubmatch(content)
|
||||
if len(matches) > 0 {
|
||||
return strings.TrimSpace(matches[0])
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -17,10 +17,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ILLMService interface {
|
||||
Call(opts *PlanningOptions) (*PlanningResult, error)
|
||||
}
|
||||
|
||||
// PlanningOptions represents the input options for planning
|
||||
type PlanningOptions struct {
|
||||
UserInstruction string `json:"user_instruction"` // append to system prompt
|
||||
@@ -132,7 +128,7 @@ func logResponse(resp *schema.Message) {
|
||||
}
|
||||
|
||||
// appendConversationHistory adds a message to the conversation history
|
||||
func appendConversationHistory(history []*schema.Message, msg *schema.Message) {
|
||||
func appendConversationHistory(history *[]*schema.Message, msg *schema.Message) {
|
||||
// for user image message:
|
||||
// - keep at most 4 user image messages
|
||||
// - delete the oldest user image message when the limit is reached
|
||||
@@ -142,7 +138,7 @@ func appendConversationHistory(history []*schema.Message, msg *schema.Message) {
|
||||
firstUserImgIndex := -1
|
||||
|
||||
// calculate the number of user messages and find the index of the first user message
|
||||
for i, item := range history {
|
||||
for i, item := range *history {
|
||||
if item.Role == schema.User {
|
||||
userImgCount++
|
||||
if firstUserImgIndex == -1 {
|
||||
@@ -154,28 +150,28 @@ func appendConversationHistory(history []*schema.Message, msg *schema.Message) {
|
||||
// if there are already 4 user messages, delete the first one before adding the new message
|
||||
if userImgCount >= 4 && firstUserImgIndex >= 0 {
|
||||
// delete the first user message
|
||||
history = append(
|
||||
history[:firstUserImgIndex],
|
||||
history[firstUserImgIndex+1:]...,
|
||||
*history = append(
|
||||
(*history)[:firstUserImgIndex],
|
||||
(*history)[firstUserImgIndex+1:]...,
|
||||
)
|
||||
}
|
||||
// add the new user message to the history
|
||||
history = append(history, msg)
|
||||
*history = append(*history, msg)
|
||||
}
|
||||
|
||||
// for assistant message:
|
||||
// - keep at most the last 10 assistant messages
|
||||
if msg.Role == schema.Assistant {
|
||||
// add the new assistant message to the history
|
||||
history = append(history, msg)
|
||||
*history = append(*history, msg)
|
||||
|
||||
// if there are more than 10 assistant messages, remove the oldest ones
|
||||
assistantMsgCount := 0
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
if history[i].Role == schema.Assistant {
|
||||
for i := len(*history) - 1; i >= 0; i-- {
|
||||
if (*history)[i].Role == schema.Assistant {
|
||||
assistantMsgCount++
|
||||
if assistantMsgCount > 10 {
|
||||
history = append(history[:i], history[i+1:]...)
|
||||
*history = append((*history)[:i], (*history)[i+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,37 +245,6 @@ func SavePositionImg(params struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadImage loads image and returns base64 encoded string
|
||||
func loadImage(imagePath string) (base64Str string, size types.Size, err error) {
|
||||
// Read the image file
|
||||
imageFile, err := os.Open(imagePath)
|
||||
if err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err)
|
||||
}
|
||||
defer imageFile.Close()
|
||||
|
||||
// Decode the image to get its resolution
|
||||
imageData, format, err := image.Decode(imageFile)
|
||||
if err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
// Get the resolution of the image
|
||||
width := imageData.Bounds().Dx()
|
||||
height := imageData.Bounds().Dy()
|
||||
size = types.Size{Width: width, Height: height}
|
||||
|
||||
// Convert image to base64
|
||||
buf := new(bytes.Buffer)
|
||||
if err := png.Encode(buf, imageData); err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
|
||||
}
|
||||
base64Str = fmt.Sprintf("data:image/%s;base64,%s", format,
|
||||
base64.StdEncoding.EncodeToString(buf.Bytes()))
|
||||
|
||||
return base64Str, size, nil
|
||||
}
|
||||
|
||||
// maskAPIKey masks the API key
|
||||
func maskAPIKey(key string) string {
|
||||
if len(key) <= 8 {
|
||||
|
||||
@@ -132,7 +132,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
}
|
||||
}
|
||||
// append user image message
|
||||
appendConversationHistory(p.history, opts.Message)
|
||||
appendConversationHistory(&p.history, opts.Message)
|
||||
|
||||
// call model service, generate response
|
||||
logRequest(p.history)
|
||||
@@ -152,7 +152,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
}
|
||||
|
||||
// append assistant message
|
||||
appendConversationHistory(p.history, &schema.Message{
|
||||
appendConversationHistory(&p.history, &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: result.ActionSummary,
|
||||
})
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@@ -408,3 +414,42 @@ func TestLoadImage(t *testing.T) {
|
||||
assert.Greater(t, jpegSize.Width, 0)
|
||||
assert.Greater(t, jpegSize.Height, 0)
|
||||
}
|
||||
|
||||
// loadImage loads image and returns base64 encoded string
|
||||
func loadImage(imagePath string) (base64Str string, size types.Size, err error) {
|
||||
// Read the image file
|
||||
imageFile, err := os.Open(imagePath)
|
||||
if err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err)
|
||||
}
|
||||
defer imageFile.Close()
|
||||
|
||||
// Decode the image to get its resolution
|
||||
imageData, format, err := image.Decode(imageFile)
|
||||
if err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
// Get the resolution of the image
|
||||
width := imageData.Bounds().Dx()
|
||||
height := imageData.Bounds().Dy()
|
||||
size = types.Size{Width: width, Height: height}
|
||||
|
||||
// Convert image to base64
|
||||
buf := new(bytes.Buffer)
|
||||
// 根据图像格式选择正确的编码器
|
||||
if format == "jpeg" || format == "jpg" {
|
||||
if err := jpeg.Encode(buf, imageData, nil); err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 默认使用 PNG 编码
|
||||
if err := png.Encode(buf, imageData); err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
|
||||
}
|
||||
}
|
||||
base64Str = fmt.Sprintf("data:image/%s;base64,%s", format,
|
||||
base64.StdEncoding.EncodeToString(buf.Bytes()))
|
||||
|
||||
return base64Str, size, nil
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
}
|
||||
}
|
||||
// append user image message
|
||||
appendConversationHistory(p.history, opts.Message)
|
||||
appendConversationHistory(&p.history, opts.Message)
|
||||
|
||||
// call model service, generate response
|
||||
logRequest(p.history)
|
||||
@@ -165,7 +165,7 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
}
|
||||
|
||||
// append assistant message
|
||||
appendConversationHistory(p.history, &schema.Message{
|
||||
appendConversationHistory(&p.history, &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: result.ActionSummary,
|
||||
})
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/httprunner/httprunner/v5/internal/config"
|
||||
"github.com/httprunner/httprunner/v5/uixt/ai"
|
||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -59,7 +60,40 @@ func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (string,
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (dExt *XTDriver) AIAssert(text string, opts ...option.ActionOption) error {
|
||||
func (dExt *XTDriver) AIAssert(assertion string, opts ...option.ActionOption) error {
|
||||
if dExt.LLMService == nil {
|
||||
return errors.New("LLM service is not initialized")
|
||||
}
|
||||
|
||||
compressedBufSource, err := dExt.GetScreenShotBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// convert buffer to base64 string
|
||||
screenShotBase64 := "data:image/jpeg;base64," +
|
||||
base64.StdEncoding.EncodeToString(compressedBufSource.Bytes())
|
||||
|
||||
// get window size
|
||||
size, err := dExt.IDriver.WindowSize()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get window size for AI assertion failed")
|
||||
}
|
||||
|
||||
// execute assertion
|
||||
result, err := dExt.LLMService.Assert(&ai.AssertOptions{
|
||||
Assertion: assertion,
|
||||
Screenshot: screenShotBase64,
|
||||
Size: types.Size{Width: size.Width, Height: size.Height},
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "AI assertion failed")
|
||||
}
|
||||
|
||||
if !result.Pass {
|
||||
return errors.New(result.Thought)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -127,6 +127,9 @@ func TestDriverExt_TapByLLM(t *testing.T) {
|
||||
driver := setupDriverExt(t)
|
||||
err := driver.AIAction("点击第一个帖子的作者头像")
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = driver.AIAssert("当前在个人介绍页")
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestDriverExt_StartToGoal(t *testing.T) {
|
||||
|
||||
@@ -185,8 +185,8 @@ func (dExt *XTDriver) DoValidation(check, assert, expected string, message ...st
|
||||
switch check {
|
||||
case SelectorOCR:
|
||||
err = dExt.assertOCR(expected, assert)
|
||||
// case SelectorAI:
|
||||
// // TODO
|
||||
case SelectorAI:
|
||||
err = dExt.AIAssert(assert)
|
||||
case SelectorForegroundApp:
|
||||
err = dExt.assertForegroundApp(expected, assert)
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user