refactor: GetModelConfig

Introduce a ModelConfig struct that wraps openai.ChatModelConfig together with its option.LLMServiceType. NewLLMService now resolves the config once via GetModelConfig and passes it to NewPlanner and NewAsserter, so the constructors no longer load environment variables or inspect EnvModelUse themselves; model-specific branches key off ModelConfig.ModelType, and constructor errors are wrapped with code.LLMPrepareRequestError.

commit fcddcfb630
parent 0e9389c796
Author: lilong.129
Date: 2025-04-30 15:17:01 +08:00
6 changed files with 63 additions and 47 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2504301431
v5.0.0-beta-2504301521

View File

@@ -20,11 +20,16 @@ type ILLMService interface {
}
func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) {
planner, err := NewPlanner(context.Background(), modelType)
modelConfig, err := GetModelConfig(modelType)
if err != nil {
return nil, err
}
asserter, err := NewAsserter(context.Background())
planner, err := NewPlanner(context.Background(), modelConfig)
if err != nil {
return nil, err
}
asserter, err := NewAsserter(context.Background(), modelConfig)
if err != nil {
return nil, err
}
@@ -36,6 +41,7 @@ func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) {
}
// combinedLLMService implements the ILLMService interface, combining planning and assertion capabilities
// ⭐️ supports using different model services for planning and assertion
type combinedLLMService struct {
planner IPlanner // 提供规划功能
asserter IAsserter // 提供断言功能
@@ -58,18 +64,20 @@ const (
EnvModelName = "LLM_MODEL_NAME"
)
var EnvModelUse string
const (
defaultTimeout = 30 * time.Second
)
// GetOpenAIModelConfig get OpenAI config
func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
type ModelConfig struct {
*openai.ChatModelConfig
ModelType option.LLMServiceType
}
// GetModelConfig gets the model config for the given LLM service type
func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
if err := config.LoadEnv(); err != nil {
return nil, errors.Wrap(code.LoadEnvError, err.Error())
}
EnvModelUse = os.Getenv("LLM_MODEL_USE")
openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
if openaiBaseURL == "" {
@@ -103,7 +111,10 @@ func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) {
Str("timeout", defaultTimeout.String()).
Msg("get model config")
return modelConfig, nil
return &ModelConfig{
ChatModelConfig: modelConfig,
ModelType: modelType,
}, nil
}
// maskAPIKey masks the API key
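For context, here is a minimal sketch of the call pattern after this refactor: resolve the configuration once with GetModelConfig, then hand the same *ModelConfig to both constructors (NewLLMService does exactly this internally). The github.com/httprunner/httprunner/v5/uixt/ai import path is an assumption inferred from the repository layout, not something stated in this diff.

package main

import (
    "context"
    "log"

    "github.com/httprunner/httprunner/v5/uixt/ai" // assumed import path for this package
    "github.com/httprunner/httprunner/v5/uixt/option"
)

func main() {
    // Resolve the env-based model settings once; the returned *ModelConfig also
    // carries the LLMServiceType, so the constructors no longer read env vars.
    modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeUITARS)
    if err != nil {
        log.Fatal(err)
    }

    planner, err := ai.NewPlanner(context.Background(), modelConfig)
    if err != nil {
        log.Fatal(err)
    }

    asserter, err := ai.NewAsserter(context.Background(), modelConfig)
    if err != nil {
        log.Fatal(err)
    }

    _, _ = planner, asserter
}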

View File

@@ -41,28 +41,23 @@ type AssertionResponse struct {
// Asserter handles assertion using different AI models
type Asserter struct {
ctx context.Context
modelConfig *ModelConfig
model model.ToolCallingChatModel
systemPrompt string
history ConversationHistory
}
// NewAsserter creates a new Asserter instance
func NewAsserter(ctx context.Context) (*Asserter, error) {
func NewAsserter(ctx context.Context, modelConfig *ModelConfig) (*Asserter, error) {
asserter := &Asserter{
ctx: ctx,
modelConfig: modelConfig,
systemPrompt: defaultAssertionPrompt,
}
config, err := GetOpenAIModelConfig()
if err != nil {
return nil, err
}
if strings.Contains(EnvModelUse, string(option.LLMServiceTypeUITARS)) {
if modelConfig.ModelType == option.LLMServiceTypeUITARS {
asserter.systemPrompt += "\n\n" + uiTarsAssertionResponseFormat
} else if strings.Contains(EnvModelUse, string(option.LLMServiceTypeQwenVL)) {
asserter.systemPrompt += "\n\n" + defaultAssertionResponseJsonFormat
} else if strings.Contains(EnvModelUse, string(option.LLMServiceTypeGPT)) {
} else if modelConfig.ModelType == option.LLMServiceTypeGPT {
// define output format
type OutputFormat struct {
Thought string `json:"thought"`
@@ -71,11 +66,11 @@ func NewAsserter(ctx context.Context) (*Asserter, error) {
}
outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
if err != nil {
return nil, err
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
}
// set structured response format
// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
config.ResponseFormat = &openai2.ChatCompletionResponseFormat{
modelConfig.ChatModelConfig.ResponseFormat = &openai2.ChatCompletionResponseFormat{
Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
Name: "assertion_result",
@@ -85,12 +80,13 @@ func NewAsserter(ctx context.Context) (*Asserter, error) {
},
}
} else {
return nil, fmt.Errorf("model type %s not supported for asserter", EnvModelUse)
asserter.systemPrompt += "\n\n" + defaultAssertionResponseJsonFormat
}
asserter.model, err = openai.NewChatModel(ctx, config)
var err error
asserter.model, err = openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
if err != nil {
return nil, err
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
}
return asserter, nil
@@ -142,7 +138,7 @@ Here is the assertion. Please tell whether it is truthy according to the screens
startTime := time.Now()
resp, err := a.model.Generate(a.ctx, a.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", EnvModelUse).Msg("call model service for assertion")
Str("model", string(a.modelConfig.ModelType)).Msg("call model service for assertion")
if err != nil {
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
}
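Note that the asserter now selects its response format from modelConfig.ModelType instead of substring-matching the EnvModelUse environment variable, and unsupported types fall back to the default JSON format instead of returning an error. A small illustration of the three branches, reusing the assumed imports from the sketch above:

// UI-TARS: uiTarsAssertionResponseFormat is appended to the system prompt.
tarsCfg, _ := ai.GetModelConfig(option.LLMServiceTypeUITARS)
_, _ = ai.NewAsserter(context.Background(), tarsCfg)

// GPT: a JSON-schema ResponseFormat is set on the underlying ChatModelConfig.
gptCfg, _ := ai.GetModelConfig(option.LLMServiceTypeGPT)
_, _ = ai.NewAsserter(context.Background(), gptCfg)

// Anything else (e.g. QwenVL): defaultAssertionResponseJsonFormat is appended.
qwenCfg, _ := ai.GetModelConfig(option.LLMServiceTypeQwenVL)
_, _ = ai.NewAsserter(context.Background(), qwenCfg)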

View File

@@ -4,13 +4,16 @@ import (
"context"
"testing"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createAsserter(t *testing.T) *Asserter {
asserter, err := NewAsserter(context.Background())
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
require.NoError(t, err)
asserter, err := NewAsserter(context.Background(), modelConfig)
require.NoError(t, err)
return asserter
}

View File

@@ -2,7 +2,6 @@ package ai
import (
"context"
"fmt"
"time"
"github.com/cloudwego/eino-ext/components/model/openai"
@@ -33,26 +32,22 @@ type PlanningResult struct {
Error string `json:"error,omitempty"`
}
func NewPlanner(ctx context.Context, modelType option.LLMServiceType) (*Planner, error) {
func NewPlanner(ctx context.Context, modelConfig *ModelConfig) (*Planner, error) {
planner := &Planner{
ctx: ctx,
modelType: modelType,
ctx: ctx,
modelConfig: modelConfig,
}
config, err := GetOpenAIModelConfig()
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI config: %w", err)
}
if modelType == option.LLMServiceTypeUITARS {
if modelConfig.ModelType == option.LLMServiceTypeUITARS {
planner.systemPrompt = uiTarsPlanningPrompt
} else {
planner.systemPrompt = defaultPlanningResponseJsonFormat
}
planner.model, err = openai.NewChatModel(ctx, config)
var err error
planner.model, err = openai.NewChatModel(ctx, modelConfig.ChatModelConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
return nil, errors.Wrap(code.LLMPrepareRequestError, err.Error())
}
return planner, nil
@@ -60,9 +55,9 @@ func NewPlanner(ctx context.Context, modelType option.LLMServiceType) (*Planner,
type Planner struct {
ctx context.Context
modelConfig *ModelConfig
model model.ToolCallingChatModel
systemPrompt string
modelType option.LLMServiceType
history ConversationHistory
}
@@ -76,11 +71,10 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
// prepare prompt
if len(p.history) == 0 {
// add system message
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
p.history = ConversationHistory{
{
Role: schema.System,
Content: systemPrompt,
Content: p.systemPrompt + opts.UserInstruction,
},
}
}
@@ -92,7 +86,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
startTime := time.Now()
resp, err := p.model.Generate(p.ctx, p.history)
log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()).
Str("model", string(p.modelType)).Msg("call model service")
Str("model", string(p.modelConfig.ModelType)).Msg("call model service")
if err != nil {
return nil, errors.Wrap(code.LLMRequestServiceError, err.Error())
}
@@ -116,7 +110,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
var parseActions []ParsedAction
var err error
if p.modelType == option.LLMServiceTypeUITARS {
if p.modelConfig.ModelType == option.LLMServiceTypeUITARS {
// parse Thought/Action format from UI-TARS
parseActions, err = parseThoughtAction(msg.Content)
if err != nil {
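For reference, planner usage after this change, sketched with the assumed imports from the first example; only UserInstruction is filled in, and any remaining PlanningOptions fields set by the tests are omitted here.

modelConfig, err := ai.GetModelConfig(option.LLMServiceTypeUITARS)
if err != nil {
    log.Fatal(err)
}
planner, err := ai.NewPlanner(context.Background(), modelConfig)
if err != nil {
    log.Fatal(err)
}

// Call seeds the conversation history with the model-specific system prompt
// plus the user instruction, and logs the config's ModelType per request.
result, err := planner.Call(&ai.PlanningOptions{
    UserInstruction: "点击第二个帖子的作者头像", // instruction borrowed from the tests
})
if err != nil {
    log.Fatal(err)
}
log.Printf("planning result: %+v", result)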

View File

@@ -36,7 +36,10 @@ func TestVLMPlanning(t *testing.T) {
userInstruction += "\n\n请基于以上游戏规则给出下一步可点击的两个图标坐标"
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
require.NoError(t, err)
planner, err := NewPlanner(context.Background(), modelConfig)
require.NoError(t, err)
opts := &PlanningOptions{
@@ -106,7 +109,10 @@ func TestXHSPlanning(t *testing.T) {
userInstruction := "点击第二个帖子的作者头像"
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
require.NoError(t, err)
planner, err := NewPlanner(context.Background(), modelConfig)
require.NoError(t, err)
opts := &PlanningOptions{
@@ -176,7 +182,10 @@ func TestChatList(t *testing.T) {
userInstruction := "请结合图片的文字信息,请告诉我一共有多少个群聊,哪些群聊右下角有绿点"
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
require.NoError(t, err)
planner, err := NewPlanner(context.Background(), modelConfig)
require.NoError(t, err)
opts := &PlanningOptions{
@@ -207,7 +216,10 @@ func TestHandleSwitch(t *testing.T) {
userInstruction := "发送框下方的联网搜索开关是开启状态" // tap to turn on the web search toggle
// Check the web search toggle below the input box: blue means on, grey means off; if it is off, tap it to turn it on
planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS)
modelConfig, err := GetModelConfig(option.LLMServiceTypeUITARS)
require.NoError(t, err)
planner, err := NewPlanner(context.Background(), modelConfig)
require.NoError(t, err)
testCases := []struct {