feat: assert with openai model

This commit is contained in:
lilong.129
2025-04-29 22:03:11 +08:00
parent f6e421fc34
commit 429bfe3986
6 changed files with 58 additions and 57 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2504292008
v5.0.0-beta-2504292203

View File

@@ -51,7 +51,6 @@ const (
LLMServiceTypeGPT4o LLMServiceType = "gpt-4o"
LLMServiceTypeGPT4Vision LLMServiceType = "gpt-4-vision"
LLMServiceTypeQwenVL LLMServiceType = "qwen-vl"
LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3"
)
// ILLMService defines the LLM service interface, including planning and assertion capabilities.

View File

@@ -35,20 +35,13 @@ func GetArkModelConfig() (*ark.ChatModelConfig, error) {
timeout := defaultTimeout
// https://www.volcengine.com/docs/82379/1494384?redirect=1
temperature := float32(0.01) // [0, 2] 采样温度。控制了生成文本时对每个候选词的概率分布进行平滑的程度。
// topP := float32(0.7) // [0, 1] 核采样概率阈值。模型会考虑概率质量在 top_p 内的 token 结果。
// maxTokens := int(4096) // 模型可以生成的最大 token 数量。输入 token 和输出 token 的总长度还受模型的上下文长度限制。
// frequencyPenalty := float32(0) // [-2, 2] 频率惩罚系数。如果值为正,会根据新 token 在文本中的出现频率对其进行惩罚,从而降低模型逐字重复的可能性。
temperature := float32(0.01)
modelConfig := &ark.ChatModelConfig{
BaseURL: arkBaseURL,
APIKey: arkAPIKey,
Model: modelName,
Timeout: &timeout,
Temperature: &temperature,
// TopP: &topP,
// MaxTokens: &maxTokens,
// FrequencyPenalty: &frequencyPenalty,
}
// log config info

View File

@@ -4,8 +4,6 @@ import (
"os"
"github.com/cloudwego/eino-ext/components/model/openai"
openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
"github.com/getkin/kin-openapi/openapi3gen"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/config"
"github.com/pkg/errors"
@@ -40,32 +38,13 @@ func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
"env %s missed", EnvModelName)
}
// OutputFormat is the value from which the JSON schema for the
// "thought_and_action" structured response format is generated.
type OutputFormat struct {
// Thought is encoded under the JSON key "thought".
Thought string `json:"thought"`
// Action is encoded under the JSON key "action".
Action string `json:"action"`
// Error is encoded under the JSON key "error" and is omitted when empty.
Error string `json:"error,omitempty"`
}
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
if err != nil {
return nil, err
}
temperature := float32(0.01)
modelConfig := &openai.ChatModelConfig{
BaseURL: openaiBaseURL,
APIKey: openaiAPIKey,
Model: modelName,
Timeout: defaultTimeout,
// set structured response format
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
ResponseFormat: &openai2.ChatCompletionResponseFormat{
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
Name: "thought_and_action",
Description: "data that describes planning thought and action",
Schema: outputFormatSchema.Value,
Strict: false,
},
},
BaseURL: openaiBaseURL,
APIKey: openaiAPIKey,
Model: modelName,
Timeout: defaultTimeout,
Temperature: &temperature,
}
// log config info

View File

@@ -9,8 +9,10 @@ import (
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/openai"
openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/getkin/kin-openapi/openapi3gen"
"github.com/httprunner/httprunner/v5/code"
"github.com/httprunner/httprunner/v5/internal/json"
"github.com/httprunner/httprunner/v5/uixt/types"
@@ -50,7 +52,7 @@ func NewAsserter(ctx context.Context, modelType LLMServiceType) (*Asserter, erro
asserter := &Asserter{
ctx: ctx,
modelType: modelType,
systemPrompt: getAssertionSystemPrompt(modelType),
systemPrompt: defaultAssertionPrompt,
}
switch modelType {
@@ -59,19 +61,56 @@ func NewAsserter(ctx context.Context, modelType LLMServiceType) (*Asserter, erro
if err != nil {
return nil, err
}
asserter.systemPrompt += "\n\n" + uiTarsAssertionResponseFormat
asserter.model, err = ark.NewChatModel(ctx, config)
if err != nil {
return nil, err
}
case LLMServiceTypeGPT4Vision, LLMServiceTypeGPT4o:
config, err := GetOpenAIModelConfig()
if err != nil {
return nil, err
}
// define output format
// OutputFormat is the value from which the JSON schema for the
// "assertion_result" structured response format is generated.
type OutputFormat struct {
// Thought is encoded under the JSON key "thought".
Thought string `json:"thought"`
// Pass is encoded under the JSON key "pass".
// NOTE(review): presumably reports whether the assertion held — confirm
// against the code that parses the model response.
Pass bool `json:"pass"`
// Error is encoded under the JSON key "error" and is omitted when empty.
Error string `json:"error,omitempty"`
}
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
if err != nil {
return nil, err
}
// set structured response format
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
config.ResponseFormat = &openai2.ChatCompletionResponseFormat{
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
Name: "assertion_result",
Description: "data that describes assertion result",
Schema: outputFormatSchema.Value,
Strict: false,
},
}
asserter.model, err = openai.NewChatModel(ctx, config)
if err != nil {
return nil, err
}
case LLMServiceTypeQwenVL:
config, err := GetOpenAIModelConfig()
if err != nil {
return nil, err
}
asserter.systemPrompt += "\n\n" + defaultAssertionResponseJsonFormat
asserter.model, err = openai.NewChatModel(ctx, config)
if err != nil {
return nil, err
}
default:
return nil, errors.New("not supported model type for asserter")
}
@@ -79,14 +118,6 @@ func NewAsserter(ctx context.Context, modelType LLMServiceType) (*Asserter, erro
return asserter, nil
}
// getAssertionSystemPrompt builds the assertion system prompt for modelType.
// The result is defaultAssertionPrompt followed by a blank line and a
// response-format section: uiTarsAssertionResponseFormat when modelType is
// LLMServiceTypeUITARS, defaultAssertionResponseJsonFormat for every other
// model type.
func getAssertionSystemPrompt(modelType LLMServiceType) string {
	// Start from the generic JSON format and override it for UI-TARS only.
	responseFormat := defaultAssertionResponseJsonFormat
	if modelType == LLMServiceTypeUITARS {
		responseFormat = uiTarsAssertionResponseFormat
	}
	return defaultAssertionPrompt + "\n\n" + responseFormat
}
// Assert performs the assertion check on the screenshot
func (a *Asserter) Assert(opts *AssertOptions) (*AssertionResponse, error) {
// Validate input parameters

View File

@@ -1,6 +1,7 @@
package ai
import (
"context"
"testing"
"github.com/httprunner/httprunner/v5/uixt/types"
@@ -8,16 +9,15 @@ import (
"github.com/stretchr/testify/require"
)
func createAIService(t *testing.T) *AIServices {
aiService := NewAIService(WithLLMService(LLMServiceTypeUITARS))
require.NotNil(t, aiService)
require.NotNil(t, aiService.ILLMService)
return aiService
func createAsserter(t *testing.T) *Asserter {
asserter, err := NewAsserter(context.Background(), LLMServiceTypeUITARS)
require.NoError(t, err)
return asserter
}
// Test valid assertions.
func TestValidAssertions(t *testing.T) {
aiService := createAIService(t)
asserter := createAsserter(t)
testCases := []struct {
name string
@@ -33,7 +33,7 @@ func TestValidAssertions(t *testing.T) {
},
{
name: "深度思考功能未开启",
assertion: "输入框下方的「深度思考」文字是灰色的",
assertion: "输入框下方的「深度思考」文字不是蓝色的",
imagePath: "testdata/deepseek_think_off.png",
expectPass: true,
},
@@ -50,7 +50,7 @@ func TestValidAssertions(t *testing.T) {
imageBase64, size, err := loadImage(tc.imagePath)
require.NoError(t, err)
result, err := aiService.ILLMService.Assert(&AssertOptions{
result, err := asserter.Assert(&AssertOptions{
Assertion: tc.assertion,
Screenshot: imageBase64,
Size: size,
@@ -58,14 +58,13 @@ func TestValidAssertions(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, tc.expectPass, result.Pass)
assert.NotEmpty(t, result.Thought)
})
}
}
// Test invalid parameters.
func TestInvalidParameters(t *testing.T) {
aiService := createAIService(t)
asserter := createAsserter(t)
testCases := []struct {
name string
assertion string
@@ -91,7 +90,7 @@ func TestInvalidParameters(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := aiService.ILLMService.Assert(&AssertOptions{
_, err := asserter.Assert(&AssertOptions{
Assertion: tc.assertion,
Screenshot: tc.screenshot,
Size: tc.size,