mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-07 06:22:43 +08:00
refactor: GetModelConfig
This commit is contained in:
@@ -1 +1 @@
|
||||
v5.0.0-beta-2504301431
|
||||
v5.0.0-beta-2504301521
|
||||
|
||||
@@ -20,11 +20,16 @@ type ILLMService interface {
|
||||
}
|
||||
|
||||
func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) {
|
||||
planner, err := NewPlanner(context.Background(), modelType)
|
||||
modelConfig, err := GetModelConfig(modelType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
asserter, err := NewAsserter(context.Background())
|
||||
|
||||
planner, err := NewPlanner(context.Background(), modelConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
asserter, err := NewAsserter(context.Background(), modelConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -36,6 +41,7 @@ func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) {
|
||||
}
|
||||
|
||||
// combinedLLMService 实现了 ILLMService 接口,组合了规划和断言功能
|
||||
// ⭐️支持采用不同的模型服务进行规划和断言
|
||||
type combinedLLMService struct {
|
||||
planner IPlanner // 提供规划功能
|
||||
asserter IAsserter // 提供断言功能
|
||||
@@ -58,18 +64,20 @@ const (
|
||||
EnvModelName = "LLM_MODEL_NAME"
|
||||
)
|
||||
|
||||
var EnvModelUse string
|
||||
|
||||
const (
|
||||
defaultTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// GetOpenAIModelConfig get OpenAI config
|
||||
func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
|
||||
type ModelConfig struct {
|
||||
*openai.ChatModelConfig
|
||||
ModelType option.LLMServiceType
|
||||
}
|
||||
|
||||
// GetModelConfig get OpenAI config
|
||||
func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
|
||||
if err := config.LoadEnv(); err != nil {
|
||||
return nil, errors.Wrap(code.LoadEnvError, err.Error())
|
||||
}
|
||||
EnvModelUse = os.Getenv("LLM_MODEL_USE")
|
||||
|
||||
openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
|
||||
if openaiBaseURL == "" {
|
||||
@@ -103,7 +111,10 @@ func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
|
||||
Str("timeout", defaultTimeout.String()).
|
||||
Msg("get model config")
|
||||
|
||||
return modelConfig, nil
|
||||
return &ModelConfig{
|
||||
ChatModelConfig: modelConfig,
|
||||
ModelType: modelType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// maskAPIKey masks the API key
|
||||
|
||||
@@ -41,28 +41,23 @@ type AssertionResponse struct {
|
||||
// Asserter handles assertion using different AI models
|
||||
type Asserter struct {
|
||||
ctx context.Context
|
||||
modelConfig *ModelConfig
|
||||
model model.ToolCallingChatModel
|
||||
systemPrompt string
|
||||
history ConversationHistory
|
||||
}
|
||||
|
||||
// NewAsserter creates a new Asserter instance
|
||||
func NewAsserter(ctx context.Context) (*Asserter, error) {
|
||||
func NewAsserter(ctx context.Context, modelConfig *ModelConfig) (*Asserter, error) {
|
||||
asserter := &Asserter{
|
||||
ctx: ctx,
|
||||
modelConfig: modelConfig,
|
||||
systemPrompt: defaultAssertionPrompt,
|
||||
}
|
||||
|
||||
config, err := GetOpenAIModelConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.Contains(EnvModelUse, string(option.LLMServiceTypeUITARS)) {
|
||||
if modelConfig.ModelType == option.LLMServiceTypeUITARS {
|
||||
asserter.systemPrompt += "\n\n" + uiTarsAssertionResponseFormat
|
||||
} else if strings.Contains(EnvModelUse, string(option.LLMServiceTypeQwenVL)) {
|
||||
asserter.systemPrompt += "\n\n" + defaultAssertionResponseJsonFormat
|
||||
} else if strings.Contains(EnvModelUse, string(option.LLMServiceTypeGPT)) {
|
||||
} else if modelConfig.ModelType == option.LLMServiceTypeGPT {
|
||||
// define output format
|
||||
type OutputFormat struct {
|
||||
Thought string `json:"thought"`
|
||||
@@ -71,11 +66,11 @@ func NewAsserter(ctx context.Context) (*Asserter, error) {
|
||||
}
|
||||
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
|
||||
}
|
||||
// set structured response format
|
||||
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
|
||||
config.ResponseFormat = &openai2.ChatCompletionResponseFormat{
|
||||
modelConfig.ChatModelConfig.ResponseFormat = &openai2.ChatCompletionResponseFormat{
|
||||
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
|
||||
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
|
||||
Name: "assertion_result",
|
||||
@@ -85,12 +80,13 @@ func NewAsserter(ctx context.Context) (*Asserter, error) {
|
||||
},
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("model type %s not supported for asserter", EnvModelUse)
|
||||
asserter.systemPrompt += "\n\n" + defaultAssertionResponseJsonFormat
|
||||
}
|
||||
|
||||
asserter.model, err = openai.NewChatModel(ctx, config)
|
||||
var err error
|
||||
asserter.model, err = openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
|
||||
}
|
||||
|
||||
return asserter, nil
|
||||
@@ -142,7 +138,7 @@ Here is the assertion. Please tell whether it is truthy according to the screens
|
||||
startTime := time.Now()
|
||||
resp, err := a.model.Generate(a.ctx, a.history)
|
||||
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
||||
Str("model", EnvModelUse).Msg("call model service for assertion")
|
||||
Str("model", string(a.modelConfig.ModelType)).Msg("call model service for assertion")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
||||
}
|
||||
|
||||
@@ -4,13 +4,16 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createAsserter(t *testing.T) *Asserter {
|
||||
asserter, err := NewAsserter(context.Background())
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
|
||||
require.NoError(t, err)
|
||||
asserter, err := NewAsserter(context.Background(), modelConfig)
|
||||
require.NoError(t, err)
|
||||
return asserter
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
@@ -33,26 +32,22 @@ type PlanningResult struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewPlanner(ctx context.Context, modelType option.LLMServiceType) (*Planner, error) {
|
||||
func NewPlanner(ctx context.Context, modelConfig *ModelConfig) (*Planner, error) {
|
||||
planner := &Planner{
|
||||
ctx: ctx,
|
||||
modelType: modelType,
|
||||
ctx: ctx,
|
||||
modelConfig: modelConfig,
|
||||
}
|
||||
|
||||
config, err := GetOpenAIModelConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
|
||||
}
|
||||
|
||||
if modelType == option.LLMServiceTypeUITARS {
|
||||
if modelConfig.ModelType == option.LLMServiceTypeUITARS {
|
||||
planner.systemPrompt = uiTarsPlanningPrompt
|
||||
} else {
|
||||
planner.systemPrompt = defaultPlanningResponseJsonFormat
|
||||
}
|
||||
|
||||
planner.model, err = openai.NewChatModel(ctx, config)
|
||||
var err error
|
||||
planner.model, err = openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
|
||||
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
|
||||
}
|
||||
|
||||
return planner, nil
|
||||
@@ -60,9 +55,9 @@ func NewPlanner(ctx context.Context, modelType option.LLMServiceType) (*Planner,
|
||||
|
||||
type Planner struct {
|
||||
ctx context.Context
|
||||
modelConfig *ModelConfig
|
||||
model model.ToolCallingChatModel
|
||||
systemPrompt string
|
||||
modelType option.LLMServiceType
|
||||
history ConversationHistory
|
||||
}
|
||||
|
||||
@@ -76,11 +71,10 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
// prepare prompt
|
||||
if len(p.history) == 0 {
|
||||
// add system message
|
||||
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
|
||||
p.history = ConversationHistory{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: systemPrompt,
|
||||
Content: p.systemPrompt + opts.UserInstruction,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -92,7 +86,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
startTime := time.Now()
|
||||
resp, err := p.model.Generate(p.ctx, p.history)
|
||||
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
|
||||
Str("model", string(p.modelType)).Msg("call model service")
|
||||
Str("model", string(p.modelConfig.ModelType)).Msg("call model service")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
|
||||
}
|
||||
@@ -116,7 +110,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
|
||||
var parseActions []ParsedAction
|
||||
var err error
|
||||
if p.modelType == option.LLMServiceTypeUITARS {
|
||||
if p.modelConfig.ModelType == option.LLMServiceTypeUITARS {
|
||||
// parse Thought/Action format from UI-TARS
|
||||
parseActions, err = parseThoughtAction(msg.Content)
|
||||
if err != nil {
|
||||
|
||||
@@ -36,7 +36,10 @@ func TestVLMPlanning(t *testing.T) {
|
||||
|
||||
userInstruction += "\n\n请基于以上游戏规则,给出下一步可点击的两个图标坐标"
|
||||
|
||||
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
|
||||
require.NoError(t, err)
|
||||
|
||||
planner, err := NewPlanner(context.Background(), modelConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := &PlanningOptions{
|
||||
@@ -106,7 +109,10 @@ func TestXHSPlanning(t *testing.T) {
|
||||
|
||||
userInstruction := "点击第二个帖子的作者头像"
|
||||
|
||||
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
|
||||
require.NoError(t, err)
|
||||
|
||||
planner, err := NewPlanner(context.Background(), modelConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := &PlanningOptions{
|
||||
@@ -176,7 +182,10 @@ func TestChatList(t *testing.T) {
|
||||
|
||||
userInstruction := "请结合图片的文字信息,请告诉我一共有多少个群聊,哪些群聊右下角有绿点"
|
||||
|
||||
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
|
||||
require.NoError(t, err)
|
||||
|
||||
planner, err := NewPlanner(context.Background(), modelConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := &PlanningOptions{
|
||||
@@ -207,7 +216,10 @@ func TestHandleSwitch(t *testing.T) {
|
||||
userInstruction := "发送框下方的联网搜索开关是开启状态" // 点击开启联网搜索开关
|
||||
// 检查发送框下方的联网搜索开关,蓝色为开启状态,灰色为关闭状态;若开关处于关闭状态,则点击进行开启
|
||||
|
||||
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
|
||||
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
|
||||
require.NoError(t, err)
|
||||
|
||||
planner, err := NewPlanner(context.Background(), modelConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
Reference in New Issue
Block a user