mirror of
https://github.com/httprunner/httprunner.git
synced 2026-06-09 01:39:39 +08:00
refactor: ai asserter
This commit is contained in:
@@ -1 +1 @@
|
|||||||
v5.0.0-beta-2504291210
|
v5.0.0-beta-2504292008
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ type LLMServiceType string
|
|||||||
const (
|
const (
|
||||||
LLMServiceTypeUITARS LLMServiceType = "ui-tars"
|
LLMServiceTypeUITARS LLMServiceType = "ui-tars"
|
||||||
LLMServiceTypeGPT4o LLMServiceType = "gpt-4o"
|
LLMServiceTypeGPT4o LLMServiceType = "gpt-4o"
|
||||||
|
LLMServiceTypeGPT4Vision LLMServiceType = "gpt-4-vision"
|
||||||
|
LLMServiceTypeQwenVL LLMServiceType = "qwen-vl"
|
||||||
LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3"
|
LLMServiceTypeDeepSeekV3 LLMServiceType = "deepseek-v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -58,45 +60,33 @@ type ILLMService interface {
|
|||||||
Assert(opts *AssertOptions) (*AssertionResponse, error)
|
Assert(opts *AssertOptions) (*AssertionResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithLLMService(service LLMServiceType) AIServiceOption {
|
func WithLLMService(modelType LLMServiceType) AIServiceOption {
|
||||||
return func(opts *AIServices) {
|
return func(opts *AIServices) {
|
||||||
switch service {
|
// init planner
|
||||||
|
var planner IPlanner
|
||||||
|
var err error
|
||||||
|
switch modelType {
|
||||||
case LLMServiceTypeGPT4o:
|
case LLMServiceTypeGPT4o:
|
||||||
// TODO: implement gpt-4o planner and asserter
|
// TODO: implement gpt-4o planner and asserter
|
||||||
planner, err := NewPlanner(context.Background())
|
planner, err = NewPlanner(context.Background())
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("init gpt-4o planner failed")
|
|
||||||
os.Exit(code.GetErrorCode(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
asserter, err := NewUITarsAsserter(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("init ui-tars asserter failed")
|
|
||||||
os.Exit(code.GetErrorCode(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
opts.ILLMService = &combinedLLMService{
|
|
||||||
planner: planner,
|
|
||||||
asserter: asserter,
|
|
||||||
}
|
|
||||||
|
|
||||||
case LLMServiceTypeUITARS:
|
case LLMServiceTypeUITARS:
|
||||||
planner, err := NewUITarsPlanner(context.Background())
|
planner, err = NewUITarsPlanner(context.Background())
|
||||||
if err != nil {
|
}
|
||||||
log.Error().Err(err).Msg("init ui-tars planner failed")
|
if err != nil {
|
||||||
os.Exit(code.GetErrorCode(err))
|
log.Error().Err(err).Msgf("init %s planner failed", modelType)
|
||||||
}
|
os.Exit(code.GetErrorCode(err))
|
||||||
|
}
|
||||||
|
|
||||||
asserter, err := NewUITarsAsserter(context.Background())
|
// init asserter
|
||||||
if err != nil {
|
asserter, err := NewAsserter(context.Background(), modelType)
|
||||||
log.Error().Err(err).Msg("init ui-tars asserter failed")
|
if err != nil {
|
||||||
os.Exit(code.GetErrorCode(err))
|
log.Error().Err(err).Msgf("init %s asserter failed", modelType)
|
||||||
}
|
os.Exit(code.GetErrorCode(err))
|
||||||
|
}
|
||||||
|
|
||||||
opts.ILLMService = &combinedLLMService{
|
opts.ILLMService = &combinedLLMService{
|
||||||
planner: planner,
|
planner: planner,
|
||||||
asserter: asserter,
|
asserter: asserter,
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
62
uixt/ai/ai_ark.go
Normal file
62
uixt/ai/ai_ark.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||||
|
"github.com/httprunner/httprunner/v5/code"
|
||||||
|
"github.com/httprunner/httprunner/v5/internal/config"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EnvArkBaseURL = "ARK_BASE_URL"
|
||||||
|
EnvArkAPIKey = "ARK_API_KEY"
|
||||||
|
EnvArkModelID = "ARK_MODEL_ID"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetArkModelConfig() (*ark.ChatModelConfig, error) {
|
||||||
|
if err := config.LoadEnv(); err != nil {
|
||||||
|
return nil, errors.Wrap(code.LoadEnvError, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
arkBaseURL := os.Getenv(EnvArkBaseURL)
|
||||||
|
arkAPIKey := os.Getenv(EnvArkAPIKey)
|
||||||
|
if arkAPIKey == "" {
|
||||||
|
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
||||||
|
"env %s missed", EnvArkAPIKey)
|
||||||
|
}
|
||||||
|
modelName := os.Getenv(EnvArkModelID)
|
||||||
|
if modelName == "" {
|
||||||
|
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
||||||
|
"env %s missed", EnvArkModelID)
|
||||||
|
}
|
||||||
|
timeout := defaultTimeout
|
||||||
|
|
||||||
|
// https://www.volcengine.com/docs/82379/1494384?redirect=1
|
||||||
|
temperature := float32(0.01) // [0, 2] 采样温度。控制了生成文本时对每个候选词的概率分布进行平滑的程度。
|
||||||
|
// topP := float32(0.7) // [0, 1] 核采样概率阈值。模型会考虑概率质量在 top_p 内的 token 结果。
|
||||||
|
// maxTokens := int(4096) // 模型可以生成的最大 token 数量。输入 token 和输出 token 的总长度还受模型的上下文长度限制。
|
||||||
|
// frequencyPenalty := float32(0) // [-2, 2] 频率惩罚系数。如果值为正,会根据新 token 在文本中的出现频率对其进行惩罚,从而降低模型逐字重复的可能性。
|
||||||
|
|
||||||
|
modelConfig := &ark.ChatModelConfig{
|
||||||
|
BaseURL: arkBaseURL,
|
||||||
|
APIKey: arkAPIKey,
|
||||||
|
Model: modelName,
|
||||||
|
Timeout: &timeout,
|
||||||
|
Temperature: &temperature,
|
||||||
|
// TopP: &topP,
|
||||||
|
// MaxTokens: &maxTokens,
|
||||||
|
// FrequencyPenalty: &frequencyPenalty,
|
||||||
|
}
|
||||||
|
|
||||||
|
// log config info
|
||||||
|
log.Info().Str("model", modelConfig.Model).
|
||||||
|
Str("baseURL", modelConfig.BaseURL).
|
||||||
|
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
|
||||||
|
Str("timeout", defaultTimeout.String()).
|
||||||
|
Msg("get model config")
|
||||||
|
|
||||||
|
return modelConfig, nil
|
||||||
|
}
|
||||||
79
uixt/ai/ai_openai.go
Normal file
79
uixt/ai/ai_openai.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
|
||||||
|
"github.com/getkin/kin-openapi/openapi3gen"
|
||||||
|
"github.com/httprunner/httprunner/v5/code"
|
||||||
|
"github.com/httprunner/httprunner/v5/internal/config"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
|
||||||
|
EnvOpenAIAPIKey = "OPENAI_API_KEY"
|
||||||
|
EnvModelName = "LLM_MODEL_NAME"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetOpenAIModelConfig get OpenAI config
|
||||||
|
func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
|
||||||
|
if err := config.LoadEnv(); err != nil {
|
||||||
|
return nil, errors.Wrap(code.LoadEnvError, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
|
||||||
|
if openaiBaseURL == "" {
|
||||||
|
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
||||||
|
"env %s missed", EnvOpenAIBaseURL)
|
||||||
|
}
|
||||||
|
openaiAPIKey := os.Getenv(EnvOpenAIAPIKey)
|
||||||
|
if openaiAPIKey == "" {
|
||||||
|
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
||||||
|
"env %s missed", EnvOpenAIAPIKey)
|
||||||
|
}
|
||||||
|
modelName := os.Getenv(EnvModelName)
|
||||||
|
if modelName == "" {
|
||||||
|
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
||||||
|
"env %s missed", EnvModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OutputFormat struct {
|
||||||
|
Thought string `json:"thought"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelConfig := &openai.ChatModelConfig{
|
||||||
|
BaseURL: openaiBaseURL,
|
||||||
|
APIKey: openaiAPIKey,
|
||||||
|
Model: modelName,
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
// set structured response format
|
||||||
|
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
|
||||||
|
ResponseFormat: &openai2.ChatCompletionResponseFormat{
|
||||||
|
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
|
||||||
|
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
|
||||||
|
Name: "thought_and_action",
|
||||||
|
Description: "data that describes planning thought and action",
|
||||||
|
Schema: outputFormatSchema.Value,
|
||||||
|
Strict: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// log config info
|
||||||
|
log.Info().Str("model", modelConfig.Model).
|
||||||
|
Str("baseURL", modelConfig.BaseURL).
|
||||||
|
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
|
||||||
|
Str("timeout", defaultTimeout.String()).
|
||||||
|
Msg("get model config")
|
||||||
|
|
||||||
|
return modelConfig, nil
|
||||||
|
}
|
||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||||
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
"github.com/cloudwego/eino/components/model"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/httprunner/httprunner/v5/code"
|
"github.com/httprunner/httprunner/v5/code"
|
||||||
"github.com/httprunner/httprunner/v5/internal/json"
|
"github.com/httprunner/httprunner/v5/internal/json"
|
||||||
@@ -16,60 +18,11 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// IAsserter interface defines the contract for assertion operations
|
||||||
type IAsserter interface {
|
type IAsserter interface {
|
||||||
Assert(opts *AssertOptions) (*AssertionResponse, error)
|
Assert(opts *AssertOptions) (*AssertionResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UI-TARS assertion system prompt
|
|
||||||
const uiTarsAssertionPrompt = `You are a senior testing engineer. User will give an assertion and a screenshot of a page. By carefully viewing the screenshot, please tell whether the assertion is truthy.
|
|
||||||
|
|
||||||
## Output Json String Format
|
|
||||||
` + "```" + `
|
|
||||||
"{
|
|
||||||
"pass": <<is a boolean value from the enum [true, false], true means the assertion is truthy>>,
|
|
||||||
"thought": "<<is a string, give the reason why the assertion is falsy or truthy. Otherwise.>>"
|
|
||||||
}"
|
|
||||||
` + "```" + `
|
|
||||||
|
|
||||||
## Rules **MUST** follow
|
|
||||||
- Make sure to return **only** the JSON, with **no additional** text or explanations.
|
|
||||||
- Use Chinese in 'thought' part.
|
|
||||||
- You **MUST** strictly follow up the **Output Json String Format**.`
|
|
||||||
|
|
||||||
// AssertionResponse represents the response from an AI assertion
|
|
||||||
type AssertionResponse struct {
|
|
||||||
Pass bool `json:"pass"`
|
|
||||||
Thought string `json:"thought"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// UITarsAsserter handles assertion using UI-TARS VLM
|
|
||||||
type UITarsAsserter struct {
|
|
||||||
ctx context.Context
|
|
||||||
model *ark.ChatModel
|
|
||||||
config *ark.ChatModelConfig
|
|
||||||
systemPrompt string
|
|
||||||
history ConversationHistory
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewUITarsAsserter creates a new UITarsAsserter instance
|
|
||||||
func NewUITarsAsserter(ctx context.Context) (*UITarsAsserter, error) {
|
|
||||||
config, err := GetArkModelConfig()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
chatModel, err := ark.NewChatModel(ctx, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UITarsAsserter{
|
|
||||||
ctx: ctx,
|
|
||||||
config: config,
|
|
||||||
model: chatModel,
|
|
||||||
systemPrompt: uiTarsAssertionPrompt,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AssertOptions represents the input options for assertion
|
// AssertOptions represents the input options for assertion
|
||||||
type AssertOptions struct {
|
type AssertOptions struct {
|
||||||
Assertion string `json:"assertion"` // The assertion text to verify
|
Assertion string `json:"assertion"` // The assertion text to verify
|
||||||
@@ -77,18 +30,65 @@ type AssertOptions struct {
|
|||||||
Size types.Size `json:"size"` // Screen dimensions
|
Size types.Size `json:"size"` // Screen dimensions
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateAssertionInput(opts *AssertOptions) error {
|
// AssertionResponse represents the response from an AI assertion
|
||||||
if opts.Assertion == "" {
|
type AssertionResponse struct {
|
||||||
return errors.Wrap(code.LLMPrepareRequestError, "assertion text is required")
|
Pass bool `json:"pass"`
|
||||||
|
Thought string `json:"thought"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Asserter handles assertion using different AI models
|
||||||
|
type Asserter struct {
|
||||||
|
ctx context.Context
|
||||||
|
model model.ToolCallingChatModel
|
||||||
|
systemPrompt string
|
||||||
|
history ConversationHistory
|
||||||
|
modelType LLMServiceType
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAsserter creates a new Asserter instance
|
||||||
|
func NewAsserter(ctx context.Context, modelType LLMServiceType) (*Asserter, error) {
|
||||||
|
asserter := &Asserter{
|
||||||
|
ctx: ctx,
|
||||||
|
modelType: modelType,
|
||||||
|
systemPrompt: getAssertionSystemPrompt(modelType),
|
||||||
}
|
}
|
||||||
if opts.Screenshot == "" {
|
|
||||||
return errors.Wrap(code.LLMPrepareRequestError, "screenshot is required")
|
switch modelType {
|
||||||
|
case LLMServiceTypeUITARS:
|
||||||
|
config, err := GetArkModelConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
asserter.model, err = ark.NewChatModel(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case LLMServiceTypeGPT4Vision, LLMServiceTypeGPT4o:
|
||||||
|
config, err := GetOpenAIModelConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
asserter.model, err = openai.NewChatModel(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, errors.New("not supported model type for asserter")
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return asserter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAssertionSystemPrompt returns the appropriate system prompt for the given model type
|
||||||
|
func getAssertionSystemPrompt(modelType LLMServiceType) string {
|
||||||
|
if modelType == LLMServiceTypeUITARS {
|
||||||
|
return defaultAssertionPrompt + "\n\n" + uiTarsAssertionResponseFormat
|
||||||
|
}
|
||||||
|
return defaultAssertionPrompt + "\n\n" + defaultAssertionResponseJsonFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert performs the assertion check on the screenshot
|
// Assert performs the assertion check on the screenshot
|
||||||
func (a *UITarsAsserter) Assert(opts *AssertOptions) (*AssertionResponse, error) {
|
func (a *Asserter) Assert(opts *AssertOptions) (*AssertionResponse, error) {
|
||||||
// Validate input parameters
|
// Validate input parameters
|
||||||
if err := validateAssertionInput(opts); err != nil {
|
if err := validateAssertionInput(opts); err != nil {
|
||||||
return nil, errors.Wrap(err, "validate assertion parameters failed")
|
return nil, errors.Wrap(err, "validate assertion parameters failed")
|
||||||
@@ -133,7 +133,7 @@ Here is the assertion. Please tell whether it is truthy according to the screens
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
resp, err := a.model.Generate(a.ctx, a.history)
|
resp, err := a.model.Generate(a.ctx, a.history)
|
||||||
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
||||||
Str("model", a.config.Model).Msg("call model service for assertion")
|
Str("model", string(a.modelType)).Msg("call model service for assertion")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
||||||
}
|
}
|
||||||
@@ -154,78 +154,36 @@ Here is the assertion. Please tell whether it is truthy according to the screens
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAssertionResult 解析模型返回的JSON响应
|
// validateAssertionInput validates the input parameters for assertion
|
||||||
|
func validateAssertionInput(opts *AssertOptions) error {
|
||||||
|
if opts.Assertion == "" {
|
||||||
|
return errors.Wrap(code.LLMPrepareRequestError, "assertion text is required")
|
||||||
|
}
|
||||||
|
if opts.Screenshot == "" {
|
||||||
|
return errors.Wrap(code.LLMPrepareRequestError, "screenshot is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAssertionResult parses the model response into AssertionResponse
|
||||||
func parseAssertionResult(content string) (*AssertionResponse, error) {
|
func parseAssertionResult(content string) (*AssertionResponse, error) {
|
||||||
// 1. 从响应中提取JSON内容
|
// Extract JSON content from response
|
||||||
jsonContent := extractJSON(content)
|
jsonContent := extractJSON(content)
|
||||||
if jsonContent == "" {
|
if jsonContent == "" {
|
||||||
return nil, errors.New("could not extract JSON from response")
|
return nil, errors.New("could not extract JSON from response")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 预处理和标准解析尝试
|
// Parse JSON response
|
||||||
jsonContent = prepareJSON(jsonContent)
|
|
||||||
var result AssertionResponse
|
var result AssertionResponse
|
||||||
if err := json.Unmarshal([]byte(jsonContent), &result); err == nil {
|
if err := json.Unmarshal([]byte(jsonContent), &result); err != nil {
|
||||||
return &result, nil
|
return nil, errors.Wrap(code.LLMParseAssertionResponseError, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 备用:正则表达式解析
|
return &result, nil
|
||||||
if pass, thought := extractWithRegex(jsonContent); thought != "" {
|
|
||||||
return &AssertionResponse{Pass: pass, Thought: thought}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("failed to parse assertion result")
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareJSON 预处理JSON字符串,修复常见问题
|
|
||||||
func prepareJSON(jsonStr string) string {
|
|
||||||
// 1. 去除可能的外层引号
|
|
||||||
jsonStr = strings.TrimSpace(jsonStr)
|
|
||||||
if strings.HasPrefix(jsonStr, "\"") && strings.HasSuffix(jsonStr, "\"") {
|
|
||||||
jsonStr = jsonStr[1 : len(jsonStr)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 转义thought内容中的引号
|
|
||||||
thoughtRegex := regexp.MustCompile(`"thought":\s*"([^"]*)"`)
|
|
||||||
matches := thoughtRegex.FindStringSubmatch(jsonStr)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
thoughtValue := matches[1]
|
|
||||||
fixedThought := strings.ReplaceAll(thoughtValue, "\"", "\\\"")
|
|
||||||
jsonStr = strings.Replace(jsonStr, matches[0], fmt.Sprintf(`"thought": "%s"`, fixedThought), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 处理换行和特殊字符
|
|
||||||
jsonStr = strings.ReplaceAll(jsonStr, "\n", "\\n")
|
|
||||||
jsonStr = strings.ReplaceAll(jsonStr, "\r", "\\r")
|
|
||||||
jsonStr = strings.ReplaceAll(jsonStr, "\t", "\\t")
|
|
||||||
|
|
||||||
return jsonStr
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractWithRegex 使用正则表达式提取pass和thought值
|
|
||||||
func extractWithRegex(jsonStr string) (pass bool, thought string) {
|
|
||||||
// 提取pass值
|
|
||||||
passRegex := regexp.MustCompile(`"pass":\s*(true|false)`)
|
|
||||||
passMatches := passRegex.FindStringSubmatch(jsonStr)
|
|
||||||
|
|
||||||
// 提取thought值
|
|
||||||
thoughtRegex := regexp.MustCompile(`"thought":\s*"([^"]*(?:"[^"]*)*)"`)
|
|
||||||
thoughtMatches := thoughtRegex.FindStringSubmatch(jsonStr)
|
|
||||||
|
|
||||||
if len(passMatches) > 1 && len(thoughtMatches) > 1 {
|
|
||||||
// 处理提取的值
|
|
||||||
pass = passMatches[1] == "true"
|
|
||||||
thought = strings.ReplaceAll(thoughtMatches[1], "\\\"", "\"")
|
|
||||||
thought = strings.ReplaceAll(thought, "\\\\", "\\")
|
|
||||||
return pass, thought
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractJSON extracts JSON content from a string that might contain markdown or other formatting
|
// extractJSON extracts JSON content from a string that might contain markdown or other formatting
|
||||||
func extractJSON(content string) string {
|
func extractJSON(content string) string {
|
||||||
// Try to extract JSON directly
|
|
||||||
content = strings.TrimSpace(content)
|
content = strings.TrimSpace(content)
|
||||||
|
|
||||||
// If the content is already a valid JSON, return it
|
// If the content is already a valid JSON, return it
|
||||||
@@ -233,7 +191,7 @@ func extractJSON(content string) string {
|
|||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for markdown code blocks with more flexible pattern
|
// Try to extract JSON from markdown code blocks
|
||||||
jsonRegex := regexp.MustCompile(`(?:json)?\s*({[\s\S]*?})\s*`)
|
jsonRegex := regexp.MustCompile(`(?:json)?\s*({[\s\S]*?})\s*`)
|
||||||
matches := jsonRegex.FindStringSubmatch(content)
|
matches := jsonRegex.FindStringSubmatch(content)
|
||||||
if len(matches) > 1 {
|
if len(matches) > 1 {
|
||||||
@@ -241,7 +199,6 @@ func extractJSON(content string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try a more robust approach for JSON with Chinese characters
|
// Try a more robust approach for JSON with Chinese characters
|
||||||
// First look for the outermost pair of curly braces
|
|
||||||
startIdx := strings.Index(content, "{")
|
startIdx := strings.Index(content, "{")
|
||||||
if startIdx >= 0 {
|
if startIdx >= 0 {
|
||||||
depth := 1
|
depth := 1
|
||||||
@@ -251,19 +208,11 @@ func extractJSON(content string) string {
|
|||||||
} else if content[i] == '}' {
|
} else if content[i] == '}' {
|
||||||
depth--
|
depth--
|
||||||
if depth == 0 {
|
if depth == 0 {
|
||||||
// Found the closing brace
|
|
||||||
return content[startIdx : i+1]
|
return content[startIdx : i+1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to regex approach
|
return content
|
||||||
braceRegex := regexp.MustCompile(`{[\s\S]*?}`)
|
|
||||||
matches = braceRegex.FindStringSubmatch(content)
|
|
||||||
if len(matches) > 0 {
|
|
||||||
return strings.TrimSpace(matches[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
25
uixt/ai/asserter_prompts.go
Normal file
25
uixt/ai/asserter_prompts.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
// Default assertion system prompt
|
||||||
|
const defaultAssertionPrompt = `You are a senior testing engineer. User will give an assertion and a screenshot of a page. By carefully viewing the screenshot, please tell whether the assertion is truthy.`
|
||||||
|
|
||||||
|
// Default assertion response format
|
||||||
|
const defaultAssertionResponseJsonFormat = `Return in the following JSON format:
|
||||||
|
{
|
||||||
|
pass: boolean, // whether the assertion is truthy
|
||||||
|
thought: string | null, // string, if the result is falsy, give the reason why it is falsy. Otherwise, put null.
|
||||||
|
}`
|
||||||
|
|
||||||
|
// UI-TARS assertion response format
|
||||||
|
const uiTarsAssertionResponseFormat = `## Output Json String Format
|
||||||
|
` + "```" + `
|
||||||
|
"{
|
||||||
|
"pass": <<is a boolean value from the enum [true, false], true means the assertion is truthy>>,
|
||||||
|
"thought": "<<is a string, give the reason why the assertion is falsy or truthy. Otherwise.>>"
|
||||||
|
}"
|
||||||
|
` + "```" + `
|
||||||
|
|
||||||
|
## Rules **MUST** follow
|
||||||
|
- Make sure to return **only** the JSON, with **no additional** text or explanations.
|
||||||
|
- Use Chinese in ` + "`Thought`" + ` part.
|
||||||
|
- You **MUST** strictly follow up the **Output Json String Format**.`
|
||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/httprunner/httprunner/v5/code"
|
"github.com/httprunner/httprunner/v5/code"
|
||||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPlanner interface {
|
type IPlanner interface {
|
||||||
@@ -85,100 +84,6 @@ func validatePlanningInput(opts *PlanningOptions) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func logRequest(messages []*schema.Message) {
|
|
||||||
msgs := make([]*schema.Message, 0, len(messages))
|
|
||||||
for _, message := range messages {
|
|
||||||
msg := &schema.Message{
|
|
||||||
Role: message.Role,
|
|
||||||
}
|
|
||||||
if message.Content != "" {
|
|
||||||
msg.Content = message.Content
|
|
||||||
} else if len(message.MultiContent) > 0 {
|
|
||||||
for _, mc := range message.MultiContent {
|
|
||||||
switch mc.Type {
|
|
||||||
case schema.ChatMessagePartTypeImageURL:
|
|
||||||
// Create a copy of the ImageURL to avoid modifying the original message
|
|
||||||
imageURLCopy := *mc.ImageURL
|
|
||||||
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
|
|
||||||
imageURLCopy.URL = "<data:image/base64...>"
|
|
||||||
}
|
|
||||||
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
|
|
||||||
Type: mc.Type,
|
|
||||||
ImageURL: &imageURLCopy,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msgs = append(msgs, msg)
|
|
||||||
}
|
|
||||||
log.Debug().Interface("messages", msgs).Msg("log request messages")
|
|
||||||
}
|
|
||||||
|
|
||||||
func logResponse(resp *schema.Message) {
|
|
||||||
logger := log.Info().Str("role", string(resp.Role)).
|
|
||||||
Str("content", resp.Content)
|
|
||||||
if resp.ResponseMeta != nil {
|
|
||||||
logger = logger.Interface("response_meta", resp.ResponseMeta)
|
|
||||||
}
|
|
||||||
if resp.Extra != nil {
|
|
||||||
logger = logger.Interface("extra", resp.Extra)
|
|
||||||
}
|
|
||||||
logger.Msg("log response message")
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConversationHistory []*schema.Message
|
|
||||||
|
|
||||||
// Append adds a message to the conversation history
|
|
||||||
func (h *ConversationHistory) Append(msg *schema.Message) {
|
|
||||||
// for user image message:
|
|
||||||
// - keep at most 4 user image messages
|
|
||||||
// - delete the oldest user image message when the limit is reached
|
|
||||||
if msg.Role == schema.User {
|
|
||||||
// get all existing user messages
|
|
||||||
userImgCount := 0
|
|
||||||
firstUserImgIndex := -1
|
|
||||||
|
|
||||||
// calculate the number of user messages and find the index of the first user message
|
|
||||||
for i, item := range *h {
|
|
||||||
if item.Role == schema.User {
|
|
||||||
userImgCount++
|
|
||||||
if firstUserImgIndex == -1 {
|
|
||||||
firstUserImgIndex = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if there are already 4 user messages, delete the first one before adding the new message
|
|
||||||
if userImgCount >= 4 && firstUserImgIndex >= 0 {
|
|
||||||
// delete the first user message
|
|
||||||
*h = append(
|
|
||||||
(*h)[:firstUserImgIndex],
|
|
||||||
(*h)[firstUserImgIndex+1:]...,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
// add the new user message to the history
|
|
||||||
*h = append(*h, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for assistant message:
|
|
||||||
// - keep at most the last 10 assistant messages
|
|
||||||
if msg.Role == schema.Assistant {
|
|
||||||
// add the new assistant message to the history
|
|
||||||
*h = append(*h, msg)
|
|
||||||
|
|
||||||
// if there are more than 10 assistant messages, remove the oldest ones
|
|
||||||
assistantMsgCount := 0
|
|
||||||
for i := len(*h) - 1; i >= 0; i-- {
|
|
||||||
if (*h)[i].Role == schema.Assistant {
|
|
||||||
assistantMsgCount++
|
|
||||||
if assistantMsgCount > 10 {
|
|
||||||
*h = append((*h)[:i], (*h)[i+1:]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SavePositionImg saves an image with position markers
|
// SavePositionImg saves an image with position markers
|
||||||
func SavePositionImg(params struct {
|
func SavePositionImg(params struct {
|
||||||
InputImgBase64 string
|
InputImgBase64 string
|
||||||
|
|||||||
@@ -4,90 +4,20 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
openai2 "github.com/cloudwego/eino-ext/libs/acl/openai"
|
|
||||||
"github.com/cloudwego/eino/components/model"
|
"github.com/cloudwego/eino/components/model"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/getkin/kin-openapi/openapi3gen"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/httprunner/httprunner/v5/code"
|
"github.com/httprunner/httprunner/v5/code"
|
||||||
"github.com/httprunner/httprunner/v5/internal/config"
|
|
||||||
"github.com/httprunner/httprunner/v5/internal/json"
|
"github.com/httprunner/httprunner/v5/internal/json"
|
||||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
EnvOpenAIBaseURL = "OPENAI_BASE_URL"
|
|
||||||
EnvOpenAIAPIKey = "OPENAI_API_KEY"
|
|
||||||
EnvModelName = "LLM_MODEL_NAME"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetOpenAIModelConfig get OpenAI config
|
|
||||||
func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
|
|
||||||
if err := config.LoadEnv(); err != nil {
|
|
||||||
return nil, errors.Wrap(code.LoadEnvError, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
|
|
||||||
if openaiBaseURL == "" {
|
|
||||||
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
|
||||||
"env %s missed", EnvOpenAIBaseURL)
|
|
||||||
}
|
|
||||||
openaiAPIKey := os.Getenv(EnvOpenAIAPIKey)
|
|
||||||
if openaiAPIKey == "" {
|
|
||||||
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
|
||||||
"env %s missed", EnvOpenAIAPIKey)
|
|
||||||
}
|
|
||||||
modelName := os.Getenv(EnvModelName)
|
|
||||||
if modelName == "" {
|
|
||||||
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
|
||||||
"env %s missed", EnvModelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
type OutputFormat struct {
|
|
||||||
Thought string `json:"thought"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
modelConfig := &openai.ChatModelConfig{
|
|
||||||
BaseURL: openaiBaseURL,
|
|
||||||
APIKey: openaiAPIKey,
|
|
||||||
Model: modelName,
|
|
||||||
Timeout: defaultTimeout,
|
|
||||||
// set structured response format
|
|
||||||
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
|
|
||||||
ResponseFormat: &openai2.ChatCompletionResponseFormat{
|
|
||||||
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
|
|
||||||
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
|
|
||||||
Name: "thought_and_action",
|
|
||||||
Description: "data that describes planning thought and action",
|
|
||||||
Schema: outputFormatSchema.Value,
|
|
||||||
Strict: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// log config info
|
|
||||||
log.Info().Str("model", modelConfig.Model).
|
|
||||||
Str("baseURL", modelConfig.BaseURL).
|
|
||||||
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
|
|
||||||
Str("timeout", defaultTimeout.String()).
|
|
||||||
Msg("get model config")
|
|
||||||
|
|
||||||
return modelConfig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPlanner(ctx context.Context) (*Planner, error) {
|
func NewPlanner(ctx context.Context) (*Planner, error) {
|
||||||
config, err := GetOpenAIModelConfig()
|
config, err := GetOpenAIModelConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -99,8 +29,8 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
|
|||||||
}
|
}
|
||||||
return &Planner{
|
return &Planner{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
|
||||||
model: model,
|
model: model,
|
||||||
|
modelType: LLMServiceTypeGPT4o,
|
||||||
systemPrompt: uiTarsPlanningPrompt, // TODO: change prompt with function calling
|
systemPrompt: uiTarsPlanningPrompt, // TODO: change prompt with function calling
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -108,8 +38,8 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
|
|||||||
type Planner struct {
|
type Planner struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
model model.ToolCallingChatModel
|
model model.ToolCallingChatModel
|
||||||
config *openai.ChatModelConfig
|
|
||||||
systemPrompt string
|
systemPrompt string
|
||||||
|
modelType LLMServiceType
|
||||||
history ConversationHistory
|
history ConversationHistory
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +69,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
resp, err := p.model.Generate(p.ctx, p.history)
|
resp, err := p.model.Generate(p.ctx, p.history)
|
||||||
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
||||||
Str("model", p.config.Model).Msg("call model service")
|
Str("model", string(p.modelType)).Msg("call model service")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
29
uixt/ai/planner_prompts.go
Normal file
29
uixt/ai/planner_prompts.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
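// uiTarsPlanningPrompt is the system prompt used by the planners. It tells
// the model to reply with a Thought/Action pair and lists the allowed actions.
// Reference: https://www.volcengine.com/docs/82379/1536429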
|
||||||
|
const uiTarsPlanningPrompt = `
|
||||||
|
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
` + "```" + `
|
||||||
|
Thought: ...
|
||||||
|
Action: ...
|
||||||
|
` + "```" + `
|
||||||
|
|
||||||
|
## Action Space
|
||||||
|
click(start_box='[x1, y1, x2, y2]')
|
||||||
|
left_double(start_box='[x1, y1, x2, y2]')
|
||||||
|
right_single(start_box='[x1, y1, x2, y2]')
|
||||||
|
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
||||||
|
hotkey(key='')
|
||||||
|
type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `.
|
||||||
|
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
|
||||||
|
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||||
|
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||||
|
|
||||||
|
## Note
|
||||||
|
- Use Chinese in ` + "`Thought`" + ` part.
|
||||||
|
- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part.
|
||||||
|
|
||||||
|
## User Instruction
|
||||||
|
`
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -14,64 +13,12 @@ import (
|
|||||||
"github.com/cloudwego/eino/components/model"
|
"github.com/cloudwego/eino/components/model"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/httprunner/httprunner/v5/code"
|
"github.com/httprunner/httprunner/v5/code"
|
||||||
"github.com/httprunner/httprunner/v5/internal/config"
|
|
||||||
"github.com/httprunner/httprunner/v5/internal/json"
|
"github.com/httprunner/httprunner/v5/internal/json"
|
||||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Names of the environment variables read by GetArkModelConfig.
const (
	// EnvArkBaseURL names the variable holding the Ark service base URL.
	EnvArkBaseURL = "ARK_BASE_URL"
	// EnvArkAPIKey names the variable holding the Ark API key.
	EnvArkAPIKey = "ARK_API_KEY"
	// EnvArkModelID names the variable holding the Ark model identifier.
	EnvArkModelID = "ARK_MODEL_ID"
)
|
|
||||||
|
|
||||||
func GetArkModelConfig() (*ark.ChatModelConfig, error) {
|
|
||||||
if err := config.LoadEnv(); err != nil {
|
|
||||||
return nil, errors.Wrap(code.LoadEnvError, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
arkBaseURL := os.Getenv(EnvArkBaseURL)
|
|
||||||
arkAPIKey := os.Getenv(EnvArkAPIKey)
|
|
||||||
if arkAPIKey == "" {
|
|
||||||
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
|
||||||
"env %s missed", EnvArkAPIKey)
|
|
||||||
}
|
|
||||||
modelName := os.Getenv(EnvArkModelID)
|
|
||||||
if modelName == "" {
|
|
||||||
return nil, errors.Wrapf(code.LLMEnvMissedError,
|
|
||||||
"env %s missed", EnvArkModelID)
|
|
||||||
}
|
|
||||||
timeout := defaultTimeout
|
|
||||||
|
|
||||||
// https://www.volcengine.com/docs/82379/1494384?redirect=1
|
|
||||||
temperature := float32(0.01) // [0, 2] 采样温度。控制了生成文本时对每个候选词的概率分布进行平滑的程度。
|
|
||||||
// topP := float32(0.7) // [0, 1] 核采样概率阈值。模型会考虑概率质量在 top_p 内的 token 结果。
|
|
||||||
// maxTokens := int(4096) // 模型可以生成的最大 token 数量。输入 token 和输出 token 的总长度还受模型的上下文长度限制。
|
|
||||||
// frequencyPenalty := float32(0) // [-2, 2] 频率惩罚系数。如果值为正,会根据新 token 在文本中的出现频率对其进行惩罚,从而降低模型逐字重复的可能性。
|
|
||||||
|
|
||||||
modelConfig := &ark.ChatModelConfig{
|
|
||||||
BaseURL: arkBaseURL,
|
|
||||||
APIKey: arkAPIKey,
|
|
||||||
Model: modelName,
|
|
||||||
Timeout: &timeout,
|
|
||||||
Temperature: &temperature,
|
|
||||||
// TopP: &topP,
|
|
||||||
// MaxTokens: &maxTokens,
|
|
||||||
// FrequencyPenalty: &frequencyPenalty,
|
|
||||||
}
|
|
||||||
|
|
||||||
// log config info
|
|
||||||
log.Info().Str("model", modelConfig.Model).
|
|
||||||
Str("baseURL", modelConfig.BaseURL).
|
|
||||||
Str("apiKey", maskAPIKey(modelConfig.APIKey)).
|
|
||||||
Str("timeout", defaultTimeout.String()).
|
|
||||||
Msg("get model config")
|
|
||||||
|
|
||||||
return modelConfig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) {
|
func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) {
|
||||||
config, err := GetArkModelConfig()
|
config, err := GetArkModelConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -84,45 +31,17 @@ func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) {
|
|||||||
|
|
||||||
return &UITarsPlanner{
|
return &UITarsPlanner{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
|
||||||
model: chatModel,
|
model: chatModel,
|
||||||
|
modelType: LLMServiceTypeUITARS,
|
||||||
systemPrompt: uiTarsPlanningPrompt,
|
systemPrompt: uiTarsPlanningPrompt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
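// uiTarsPlanningPrompt is the system prompt used by the planners. It tells
// the model to reply with a Thought/Action pair and lists the allowed actions.
// Reference: https://www.volcengine.com/docs/82379/1536429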
|
|
||||||
const uiTarsPlanningPrompt = `
|
|
||||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
|
||||||
|
|
||||||
## Output Format
|
|
||||||
` + "```" + `
|
|
||||||
Thought: ...
|
|
||||||
Action: ...
|
|
||||||
` + "```" + `
|
|
||||||
|
|
||||||
## Action Space
|
|
||||||
click(start_box='[x1, y1, x2, y2]')
|
|
||||||
left_double(start_box='[x1, y1, x2, y2]')
|
|
||||||
right_single(start_box='[x1, y1, x2, y2]')
|
|
||||||
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
|
||||||
hotkey(key='')
|
|
||||||
type(content='') #If you want to submit your input, use "\n" at the end of ` + "`content`" + `.
|
|
||||||
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
|
|
||||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
|
||||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
|
||||||
|
|
||||||
## Note
|
|
||||||
- Use Chinese in ` + "`Thought`" + ` part.
|
|
||||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in ` + "`Thought`" + ` part.
|
|
||||||
|
|
||||||
## User Instruction
|
|
||||||
`
|
|
||||||
|
|
||||||
type UITarsPlanner struct {
|
type UITarsPlanner struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
model model.ToolCallingChatModel
|
model model.ToolCallingChatModel
|
||||||
config *ark.ChatModelConfig
|
|
||||||
systemPrompt string
|
systemPrompt string
|
||||||
|
modelType LLMServiceType
|
||||||
history ConversationHistory
|
history ConversationHistory
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,7 +71,7 @@ func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
resp, err := p.model.Generate(p.ctx, p.history)
|
resp, err := p.model.Generate(p.ctx, p.history)
|
||||||
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
||||||
Str("model", p.config.Model).Msg("call model service")
|
Str("model", string(p.modelType)).Msg("call model service")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
103
uixt/ai/session.go
Normal file
103
uixt/ai/session.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationHistory represents a sequence of chat messages
|
||||||
|
type ConversationHistory []*schema.Message
|
||||||
|
|
||||||
|
// Append adds a new message to the conversation history
|
||||||
|
func (h *ConversationHistory) Append(msg *schema.Message) {
|
||||||
|
// for user image message:
|
||||||
|
// - keep at most 4 user image messages
|
||||||
|
// - delete the oldest user image message when the limit is reached
|
||||||
|
if msg.Role == schema.User {
|
||||||
|
// get all existing user messages
|
||||||
|
userImgCount := 0
|
||||||
|
firstUserImgIndex := -1
|
||||||
|
|
||||||
|
// calculate the number of user messages and find the index of the first user message
|
||||||
|
for i, item := range *h {
|
||||||
|
if item.Role == schema.User {
|
||||||
|
userImgCount++
|
||||||
|
if firstUserImgIndex == -1 {
|
||||||
|
firstUserImgIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there are already 4 user messages, delete the first one before adding the new message
|
||||||
|
if userImgCount >= 4 && firstUserImgIndex >= 0 {
|
||||||
|
// delete the first user message
|
||||||
|
*h = append(
|
||||||
|
(*h)[:firstUserImgIndex],
|
||||||
|
(*h)[firstUserImgIndex+1:]...,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
// add the new user message to the history
|
||||||
|
*h = append(*h, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for assistant message:
|
||||||
|
// - keep at most the last 10 assistant messages
|
||||||
|
if msg.Role == schema.Assistant {
|
||||||
|
// add the new assistant message to the history
|
||||||
|
*h = append(*h, msg)
|
||||||
|
|
||||||
|
// if there are more than 10 assistant messages, remove the oldest ones
|
||||||
|
assistantMsgCount := 0
|
||||||
|
for i := len(*h) - 1; i >= 0; i-- {
|
||||||
|
if (*h)[i].Role == schema.Assistant {
|
||||||
|
assistantMsgCount++
|
||||||
|
if assistantMsgCount > 10 {
|
||||||
|
*h = append((*h)[:i], (*h)[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logRequest(messages ConversationHistory) {
|
||||||
|
msgs := make(ConversationHistory, 0, len(messages))
|
||||||
|
for _, message := range messages {
|
||||||
|
msg := &schema.Message{
|
||||||
|
Role: message.Role,
|
||||||
|
}
|
||||||
|
if message.Content != "" {
|
||||||
|
msg.Content = message.Content
|
||||||
|
} else if len(message.MultiContent) > 0 {
|
||||||
|
for _, mc := range message.MultiContent {
|
||||||
|
switch mc.Type {
|
||||||
|
case schema.ChatMessagePartTypeImageURL:
|
||||||
|
// Create a copy of the ImageURL to avoid modifying the original message
|
||||||
|
imageURLCopy := *mc.ImageURL
|
||||||
|
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
|
||||||
|
imageURLCopy.URL = "<data:image/base64...>"
|
||||||
|
}
|
||||||
|
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
|
||||||
|
Type: mc.Type,
|
||||||
|
ImageURL: &imageURLCopy,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msgs = append(msgs, msg)
|
||||||
|
}
|
||||||
|
log.Debug().Interface("messages", msgs).Msg("log request messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func logResponse(resp *schema.Message) {
|
||||||
|
logger := log.Info().Str("role", string(resp.Role)).
|
||||||
|
Str("content", resp.Content)
|
||||||
|
if resp.ResponseMeta != nil {
|
||||||
|
logger = logger.Interface("response_meta", resp.ResponseMeta)
|
||||||
|
}
|
||||||
|
if resp.Extra != nil {
|
||||||
|
logger = logger.Interface("extra", resp.Extra)
|
||||||
|
}
|
||||||
|
logger.Msg("log response message")
|
||||||
|
}
|
||||||
BIN
uixt/ai/testdata/llk_4.png
vendored
Normal file
BIN
uixt/ai/testdata/llk_4.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 585 KiB |
Reference in New Issue
Block a user