diff --git a/internal/version/VERSION b/internal/version/VERSION index 77d42ce7..41dde217 100644 --- a/internal/version/VERSION +++ b/internal/version/VERSION @@ -1 +1 @@ -v5.0.0-beta-2504292314 +v5.0.0-beta-2504301407 diff --git a/uixt/ai/ai.go b/uixt/ai/ai.go index 064a9f37..75f6cf56 100644 --- a/uixt/ai/ai.go +++ b/uixt/ai/ai.go @@ -3,92 +3,36 @@ package ai import ( "context" "os" + "time" "github.com/cloudwego/eino-ext/components/model/openai" "github.com/httprunner/httprunner/v5/code" "github.com/httprunner/httprunner/v5/internal/config" + "github.com/httprunner/httprunner/v5/uixt/option" "github.com/pkg/errors" "github.com/rs/zerolog/log" ) -func NewAIService(opts ...AIServiceOption) *AIServices { - services := &AIServices{} - for _, option := range opts { - option(services) - } - return services -} - -type AIServices struct { - ICVService - ILLMService -} - -type AIServiceOption func(*AIServices) - -type CVServiceType string - -const ( - CVServiceTypeVEDEM CVServiceType = "vedem" - CVServiceTypeOpenCV CVServiceType = "opencv" -) - -func WithCVService(service CVServiceType) AIServiceOption { - return func(opts *AIServices) { - if service == CVServiceTypeVEDEM { - var err error - opts.ICVService, err = NewVEDEMImageService() - if err != nil { - log.Error().Err(err).Msg("init vedem image service failed") - os.Exit(code.GetErrorCode(err)) - } - } - } -} - -type LLMServiceType string - -const ( - LLMServiceTypeUITARS LLMServiceType = "ui-tars" - LLMServiceTypeGPT LLMServiceType = "gpt" - LLMServiceTypeQwenVL LLMServiceType = "qwen-vl" -) - // ILLMService 定义了 LLM 服务接口,包括规划和断言功能 type ILLMService interface { Call(opts *PlanningOptions) (*PlanningResult, error) Assert(opts *AssertOptions) (*AssertionResponse, error) } -func WithLLMService(modelType LLMServiceType) AIServiceOption { - return func(opts *AIServices) { - // init planner - var planner IPlanner - var err error - switch modelType { - case LLMServiceTypeGPT: - // TODO: implement gpt-4o planner and asserter - planner, err = NewPlanner(context.Background()) - case LLMServiceTypeUITARS: - planner, err = NewUITarsPlanner(context.Background()) - } - if err != nil { - log.Error().Err(err).Msgf("init %s planner failed", modelType) - os.Exit(code.GetErrorCode(err)) - } - - // init asserter - asserter, err := NewAsserter(context.Background()) - if err != nil { - log.Error().Err(err).Msg("init asserter failed") - os.Exit(code.GetErrorCode(err)) - } - - opts.ILLMService = &combinedLLMService{ - planner: planner, - asserter: asserter, - } +func NewLLMService(modelType option.LLMServiceType) (ILLMService, error) { + planner, err := NewPlanner(context.Background(), modelType) + if err != nil { + return nil, err } + asserter, err := NewAsserter(context.Background()) + if err != nil { + return nil, err + } + + return &combinedLLMService{ + planner: planner, + asserter: asserter, + }, nil } // combinedLLMService 实现了 ILLMService 接口,组合了规划和断言功能 @@ -116,6 +60,10 @@ const ( var EnvModelUse string +const ( + defaultTimeout = 30 * time.Second +) + // GetOpenAIModelConfig get OpenAI config func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) { if err := config.LoadEnv(); err != nil { @@ -157,3 +105,12 @@ func GetOpenAIModelConfig() (*openai.ChatModelConfig, error) { return modelConfig, nil } + +// maskAPIKey masks the API key +func maskAPIKey(key string) string { + if len(key) <= 8 { + return "******" + } + + return key[:4] + "******" + key[len(key)-4:] +} diff --git a/uixt/ai/planner.go b/uixt/ai/planner.go index 39435570..00b76e80 100644 --- a/uixt/ai/planner.go +++ b/uixt/ai/planner.go @@ -1,21 +1,18 @@ package ai import ( - "bytes" - "encoding/base64" + "context" "fmt" - "image" - "image/color" - "image/draw" - "image/png" - "os" - "strings" "time" + "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" "github.com/httprunner/httprunner/v5/code" + "github.com/httprunner/httprunner/v5/uixt/option" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/pkg/errors" + "github.com/rs/zerolog/log" ) type IPlanner interface { @@ -36,30 +33,115 @@ type PlanningResult struct { Error string `json:"error,omitempty"` } -// ParsedAction represents a parsed action from the VLM response -type ParsedAction struct { - ActionType ActionType `json:"actionType"` - ActionInputs map[string]interface{} `json:"actionInputs"` - Thought string `json:"thought"` +func NewPlanner(ctx context.Context, modelType option.LLMServiceType) (*Planner, error) { + planner := &Planner{ + ctx: ctx, + modelType: modelType, + } + + config, err := GetOpenAIModelConfig() + if err != nil { + return nil, fmt.Errorf("failed to create OpenAI config: %w", err) + } + + if modelType == option.LLMServiceTypeUITARS { + planner.systemPrompt = uiTarsPlanningPrompt + } else { + planner.systemPrompt = defaultPlanningResponseJsonFormat + } + + planner.model, err = openai.NewChatModel(ctx, config) + if err != nil { + return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err) + } + + return planner, nil } -type ActionType string +type Planner struct { + ctx context.Context + model model.ToolCallingChatModel + systemPrompt string + modelType option.LLMServiceType + history ConversationHistory +} -const ( - ActionTypeClick ActionType = "click" - ActionTypeTap ActionType = "tap" - ActionTypeDrag ActionType = "drag" - ActionTypeSwipe ActionType = "swipe" - ActionTypeWait ActionType = "wait" - ActionTypeFinished ActionType = "finished" - ActionTypeCallUser ActionType = "call_user" - ActionTypeType ActionType = "type" - ActionTypeScroll ActionType = "scroll" -) +// Call performs UI planning using Vision Language Model +func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { + // validate input parameters + if err := validatePlanningInput(opts); err != nil { + return nil, errors.Wrap(err, "validate planning parameters failed") + } -const ( - defaultTimeout = 30 * time.Second -) + // prepare prompt + if len(p.history) == 0 { + // add system message + systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction + p.history = ConversationHistory{ + { + Role: schema.System, + Content: systemPrompt, + }, + } + } + // append user image message + p.history.Append(opts.Message) + + // call model service, generate response + logRequest(p.history) + startTime := time.Now() + resp, err := p.model.Generate(p.ctx, p.history) + log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). + Str("model", string(p.modelType)).Msg("call model service") + if err != nil { + return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) + } + logResponse(resp) + + // parse result + result, err := p.parseResult(resp, opts.Size) + if err != nil { + return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error()) + } + + // append assistant message + p.history.Append(&schema.Message{ + Role: schema.Assistant, + Content: result.ActionSummary, + }) + + return result, nil +} + +func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) { + var parseActions []ParsedAction + var err error + if p.modelType == option.LLMServiceTypeUITARS { + // parse Thought/Action format from UI-TARS + parseActions, err = parseThoughtAction(msg.Content) + if err != nil { + return nil, err + } + } else { + // parse JSON format, from VLM like openai/gpt-4o + parseActions, err = parseJSON(msg.Content) + if err != nil { + return nil, err + } + } + + // process response + result, err := processVLMResponse(parseActions, size) + if err != nil { + return nil, errors.Wrap(err, "process VLM response failed") + } + + log.Info(). + Interface("summary", result.ActionSummary). + Interface("actions", result.NextActions). + Msg("get VLM planning result") + return result, nil +} func validatePlanningInput(opts *PlanningOptions) error { if opts.UserInstruction == "" { @@ -83,79 +165,3 @@ func validatePlanningInput(opts *PlanningOptions) error { return nil } - -// SavePositionImg saves an image with position markers -func SavePositionImg(params struct { - InputImgBase64 string - Rect struct { - X float64 - Y float64 - } - OutputPath string -}) error { - // 解码Base64图像 - imgData := params.InputImgBase64 - // 如果包含了数据URL前缀,去掉它 - if strings.HasPrefix(imgData, "data:image/") { - parts := strings.Split(imgData, ",") - if len(parts) > 1 { - imgData = parts[1] - } - } - - // 解码Base64 - unbased, err := base64.StdEncoding.DecodeString(imgData) - if err != nil { - return fmt.Errorf("无法解码Base64图像: %w", err) - } - - // 解码图像 - reader := bytes.NewReader(unbased) - img, _, err := image.Decode(reader) - if err != nil { - return fmt.Errorf("无法解码图像数据: %w", err) - } - - // 创建一个可以在其上绘制的图像 - bounds := img.Bounds() - rgba := image.NewRGBA(bounds) - draw.Draw(rgba, bounds, img, bounds.Min, draw.Src) - - // 在点击/拖动位置绘制标记 - markRadius := 10 - x, y := int(params.Rect.X), int(params.Rect.Y) - - // 绘制红色圆圈 - for i := -markRadius; i <= markRadius; i++ { - for j := -markRadius; j <= markRadius; j++ { - if i*i+j*j <= markRadius*markRadius { - if x+i >= 0 && x+i < bounds.Max.X && y+j >= 0 && y+j < bounds.Max.Y { - rgba.Set(x+i, y+j, color.RGBA{255, 0, 0, 255}) - } - } - } - } - - // 保存图像 - outFile, err := os.Create(params.OutputPath) - if err != nil { - return fmt.Errorf("无法创建输出文件: %w", err) - } - defer outFile.Close() - - // 编码为PNG并保存 - if err := png.Encode(outFile, rgba); err != nil { - return fmt.Errorf("无法编码和保存图像: %w", err) - } - - return nil -} - -// maskAPIKey masks the API key -func maskAPIKey(key string) string { - if len(key) <= 8 { - return "******" - } - - return key[:4] + "******" + key[len(key)-4:] -} diff --git a/uixt/ai/planner_gpt.go b/uixt/ai/planner_gpt.go deleted file mode 100644 index a20de2d4..00000000 --- a/uixt/ai/planner_gpt.go +++ /dev/null @@ -1,172 +0,0 @@ -package ai - -import ( - "context" - "fmt" - _ "image/jpeg" - "strings" - "time" - - "github.com/cloudwego/eino-ext/components/model/openai" - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/schema" - "github.com/pkg/errors" - "github.com/rs/zerolog/log" - - "github.com/httprunner/httprunner/v5/code" - "github.com/httprunner/httprunner/v5/internal/json" - "github.com/httprunner/httprunner/v5/uixt/types" -) - -func NewPlanner(ctx context.Context) (*Planner, error) { - config, err := GetOpenAIModelConfig() - if err != nil { - return nil, fmt.Errorf("failed to create OpenAI config: %w", err) - } - model, err := openai.NewChatModel(ctx, config) - if err != nil { - return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err) - } - return &Planner{ - ctx: ctx, - model: model, - modelType: LLMServiceTypeGPT, - systemPrompt: uiTarsPlanningPrompt, // TODO: change prompt with function calling - }, nil -} - -type Planner struct { - ctx context.Context - model model.ToolCallingChatModel - systemPrompt string - modelType LLMServiceType - history ConversationHistory -} - -// Call performs UI planning using Vision Language Model -func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { - // validate input parameters - if err := validatePlanningInput(opts); err != nil { - return nil, errors.Wrap(err, "validate planning parameters failed") - } - - // prepare prompt - if len(p.history) == 0 { - // add system message - systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction - p.history = ConversationHistory{ - { - Role: schema.System, - Content: systemPrompt, - }, - } - } - // append user image message - p.history.Append(opts.Message) - - // call model service, generate response - logRequest(p.history) - startTime := time.Now() - resp, err := p.model.Generate(p.ctx, p.history) - log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", string(p.modelType)).Msg("call model service") - if err != nil { - return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) - } - logResponse(resp) - - // parse result - result, err := p.parseResult(resp, opts.Size) - if err != nil { - return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error()) - } - - // append assistant message - p.history.Append(&schema.Message{ - Role: schema.Assistant, - Content: result.ActionSummary, - }) - - return result, nil -} - -func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) { - // parse JSON format, from VLM like openai/gpt-4o - parseActions, jsonErr := parseJSON(msg.Content) - if jsonErr != nil { - return nil, jsonErr - } - - // process response - result, err := processVLMResponse(parseActions, size) - if err != nil { - return nil, errors.Wrap(err, "process VLM response failed") - } - - log.Info(). - Interface("summary", result.ActionSummary). - Interface("actions", result.NextActions). - Msg("get VLM planning result") - return result, nil -} - -// parseJSON tries to parse the response as JSON format -func parseJSON(predictionText string) ([]ParsedAction, error) { - predictionText = strings.TrimSpace(predictionText) - if strings.HasPrefix(predictionText, "```json") && strings.HasSuffix(predictionText, "```") { - predictionText = strings.TrimPrefix(predictionText, "```json") - predictionText = strings.TrimSuffix(predictionText, "```") - } - predictionText = strings.TrimSpace(predictionText) - - var response PlanningResult - if err := json.Unmarshal([]byte(predictionText), &response); err != nil { - return nil, fmt.Errorf("failed to parse VLM response: %v", err) - } - - if response.Error != "" { - return nil, errors.New(response.Error) - } - - if len(response.NextActions) == 0 { - return nil, errors.New("no actions returned from VLM") - } - - // normalize actions - var normalizedActions []ParsedAction - for i := range response.NextActions { - // create a new variable, avoid implicit memory aliasing in for loop. - action := response.NextActions[i] - if err := normalizeAction(&action); err != nil { - return nil, errors.Wrap(err, "failed to normalize action") - } - normalizedActions = append(normalizedActions, action) - } - - return normalizedActions, nil -} - -// normalizeAction normalizes the coordinates in the action -func normalizeAction(action *ParsedAction) error { - switch action.ActionType { - case "click", "drag": - // handle click and drag action coordinates - if startBox, ok := action.ActionInputs["startBox"].(string); ok { - normalized, err := normalizeCoordinates(startBox) - if err != nil { - return fmt.Errorf("failed to normalize startBox: %w", err) - } - action.ActionInputs["startBox"] = normalized - } - - if endBox, ok := action.ActionInputs["endBox"].(string); ok { - normalized, err := normalizeCoordinates(endBox) - if err != nil { - return fmt.Errorf("failed to normalize endBox: %w", err) - } - action.ActionInputs["endBox"] = normalized - } - } - - return nil -} diff --git a/uixt/ai/planner_ui_tars.go b/uixt/ai/planner_parser.go similarity index 79% rename from uixt/ai/planner_ui_tars.go rename to uixt/ai/planner_parser.go index bbd22dfa..a86be3e3 100644 --- a/uixt/ai/planner_ui_tars.go +++ b/uixt/ai/planner_parser.go @@ -1,116 +1,38 @@ package ai import ( - "context" "fmt" "math" "regexp" "strconv" "strings" - "time" - "github.com/cloudwego/eino-ext/components/model/openai" - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/schema" - "github.com/httprunner/httprunner/v5/code" "github.com/httprunner/httprunner/v5/internal/json" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/pkg/errors" "github.com/rs/zerolog/log" ) -func NewUITarsPlanner(ctx context.Context) (*UITarsPlanner, error) { - config, err := GetOpenAIModelConfig() - if err != nil { - return nil, err - } - chatModel, err := openai.NewChatModel(ctx, config) - if err != nil { - return nil, err - } - - return &UITarsPlanner{ - ctx: ctx, - model: chatModel, - modelType: LLMServiceTypeUITARS, - systemPrompt: uiTarsPlanningPrompt, - }, nil +// ParsedAction represents a parsed action from the VLM response +type ParsedAction struct { + ActionType ActionType `json:"actionType"` + ActionInputs map[string]interface{} `json:"actionInputs"` + Thought string `json:"thought"` } -type UITarsPlanner struct { - ctx context.Context - model model.ToolCallingChatModel - systemPrompt string - modelType LLMServiceType - history ConversationHistory -} +type ActionType string -// Call performs UI planning using Vision Language Model -func (p *UITarsPlanner) Call(opts *PlanningOptions) (*PlanningResult, error) { - // validate input parameters - if err := validatePlanningInput(opts); err != nil { - return nil, errors.Wrap(err, "validate planning parameters failed") - } - - // prepare prompt - if len(p.history) == 0 { - // add system message - systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction - p.history = ConversationHistory{ - { - Role: schema.System, - Content: systemPrompt, - }, - } - } - // append user image message - p.history.Append(opts.Message) - - // call model service, generate response - logRequest(p.history) - startTime := time.Now() - resp, err := p.model.Generate(p.ctx, p.history) - log.Info().Float64("elapsed(s)", time.Since(startTime).Seconds()). - Str("model", string(p.modelType)).Msg("call model service") - if err != nil { - return nil, errors.Wrap(code.LLMRequestServiceError, err.Error()) - } - logResponse(resp) - - // parse result - result, err := p.parseResult(resp, opts.Size) - if err != nil { - return nil, errors.Wrap(code.LLMParsePlanningResponseError, err.Error()) - } - - // append assistant message - p.history.Append(&schema.Message{ - Role: schema.Assistant, - Content: result.ActionSummary, - }) - - return result, nil -} - -func (p *UITarsPlanner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) { - // parse Thought/Action format from UI-TARS - parseActions, thoughtErr := parseThoughtAction(msg.Content) - if thoughtErr != nil { - return nil, thoughtErr - } - - // process response - result, err := processVLMResponse(parseActions, size) - if err != nil { - return nil, errors.Wrap(err, "process VLM response failed") - } - - log.Info(). - Interface("summary", result.ActionSummary). - Interface("actions", result.NextActions). - Msg("get VLM planning result") - return result, nil -} +const ( + ActionTypeClick ActionType = "click" + ActionTypeTap ActionType = "tap" + ActionTypeDrag ActionType = "drag" + ActionTypeSwipe ActionType = "swipe" + ActionTypeWait ActionType = "wait" + ActionTypeFinished ActionType = "finished" + ActionTypeCallUser ActionType = "call_user" + ActionTypeType ActionType = "type" + ActionTypeScroll ActionType = "scroll" +) // parseThoughtAction parses the Thought/Action format response func parseThoughtAction(predictionText string) ([]ParsedAction, error) { @@ -396,3 +318,64 @@ func validateTypeContent(action *ParsedAction) { log.Warn().Msg("type action missing content parameter, set to default") } } + +// parseJSON tries to parse the response as JSON format +func parseJSON(predictionText string) ([]ParsedAction, error) { + predictionText = strings.TrimSpace(predictionText) + if strings.HasPrefix(predictionText, "```json") && strings.HasSuffix(predictionText, "```") { + predictionText = strings.TrimPrefix(predictionText, "```json") + predictionText = strings.TrimSuffix(predictionText, "```") + } + predictionText = strings.TrimSpace(predictionText) + + var response PlanningResult + if err := json.Unmarshal([]byte(predictionText), &response); err != nil { + return nil, fmt.Errorf("failed to parse VLM response: %v", err) + } + + if response.Error != "" { + return nil, errors.New(response.Error) + } + + if len(response.NextActions) == 0 { + return nil, errors.New("no actions returned from VLM") + } + + // normalize actions + var normalizedActions []ParsedAction + for i := range response.NextActions { + // create a new variable, avoid implicit memory aliasing in for loop. + action := response.NextActions[i] + if err := normalizeAction(&action); err != nil { + return nil, errors.Wrap(err, "failed to normalize action") + } + normalizedActions = append(normalizedActions, action) + } + + return normalizedActions, nil +} + +// normalizeAction normalizes the coordinates in the action +func normalizeAction(action *ParsedAction) error { + switch action.ActionType { + case "click", "drag": + // handle click and drag action coordinates + if startBox, ok := action.ActionInputs["startBox"].(string); ok { + normalized, err := normalizeCoordinates(startBox) + if err != nil { + return fmt.Errorf("failed to normalize startBox: %w", err) + } + action.ActionInputs["startBox"] = normalized + } + + if endBox, ok := action.ActionInputs["endBox"].(string); ok { + normalized, err := normalizeCoordinates(endBox) + if err != nil { + return fmt.Errorf("failed to normalize endBox: %w", err) + } + action.ActionInputs["endBox"] = normalized + } + } + + return nil +} diff --git a/uixt/ai/planner_prompts.go b/uixt/ai/planner_prompts.go index d303c87f..e9c0f45b 100644 --- a/uixt/ai/planner_prompts.go +++ b/uixt/ai/planner_prompts.go @@ -27,3 +27,5 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par ## User Instruction ` + +const defaultPlanningResponseJsonFormat = `` diff --git a/uixt/ai/planner_test.go b/uixt/ai/planner_test.go index 68a988b3..9f3faddb 100644 --- a/uixt/ai/planner_test.go +++ b/uixt/ai/planner_test.go @@ -13,6 +13,7 @@ import ( "github.com/cloudwego/eino/schema" "github.com/httprunner/httprunner/v5/code" + "github.com/httprunner/httprunner/v5/uixt/option" "github.com/httprunner/httprunner/v5/uixt/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,7 +36,7 @@ func TestVLMPlanning(t *testing.T) { userInstruction += "\n\n请基于以上游戏规则,给出下一步可点击的两个图标坐标" - planner, err := NewUITarsPlanner(context.Background()) + planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) require.NoError(t, err) opts := &PlanningOptions{ @@ -105,7 +106,7 @@ func TestXHSPlanning(t *testing.T) { userInstruction := "点击第二个帖子的作者头像" - planner, err := NewUITarsPlanner(context.Background()) + planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) require.NoError(t, err) opts := &PlanningOptions{ @@ -175,7 +176,7 @@ func TestChatList(t *testing.T) { userInstruction := "请结合图片的文字信息,请告诉我一共有多少个群聊,哪些群聊右下角有绿点" - planner, err := NewUITarsPlanner(context.Background()) + planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) require.NoError(t, err) opts := &PlanningOptions{ @@ -206,7 +207,7 @@ func TestHandleSwitch(t *testing.T) { userInstruction := "发送框下方的联网搜索开关是开启状态" // 点击开启联网搜索开关 // 检查发送框下方的联网搜索开关,蓝色为开启状态,灰色为关闭状态;若开关处于关闭状态,则点击进行开启 - planner, err := NewUITarsPlanner(context.Background()) + planner, err := NewPlanner(context.Background(), option.LLMServiceTypeUITARS) require.NoError(t, err) testCases := []struct { diff --git a/uixt/ai/session.go b/uixt/ai/session.go index 659ccc88..dc8c0b6d 100644 --- a/uixt/ai/session.go +++ b/uixt/ai/session.go @@ -1,6 +1,14 @@ package ai import ( + "bytes" + "encoding/base64" + "fmt" + "image" + "image/color" + "image/draw" + "image/png" + "os" "strings" "github.com/cloudwego/eino/schema" @@ -101,3 +109,70 @@ func logResponse(resp *schema.Message) { } logger.Msg("log response message") } + +// SavePositionImg saves an image with position markers +func SavePositionImg(params struct { + InputImgBase64 string + Rect struct { + X float64 + Y float64 + } + OutputPath string +}) error { + // 解码Base64图像 + imgData := params.InputImgBase64 + // 如果包含了数据URL前缀,去掉它 + if strings.HasPrefix(imgData, "data:image/") { + parts := strings.Split(imgData, ",") + if len(parts) > 1 { + imgData = parts[1] + } + } + + // 解码Base64 + unbased, err := base64.StdEncoding.DecodeString(imgData) + if err != nil { + return fmt.Errorf("无法解码Base64图像: %w", err) + } + + // 解码图像 + reader := bytes.NewReader(unbased) + img, _, err := image.Decode(reader) + if err != nil { + return fmt.Errorf("无法解码图像数据: %w", err) + } + + // 创建一个可以在其上绘制的图像 + bounds := img.Bounds() + rgba := image.NewRGBA(bounds) + draw.Draw(rgba, bounds, img, bounds.Min, draw.Src) + + // 在点击/拖动位置绘制标记 + markRadius := 10 + x, y := int(params.Rect.X), int(params.Rect.Y) + + // 绘制红色圆圈 + for i := -markRadius; i <= markRadius; i++ { + for j := -markRadius; j <= markRadius; j++ { + if i*i+j*j <= markRadius*markRadius { + if x+i >= 0 && x+i < bounds.Max.X && y+j >= 0 && y+j < bounds.Max.Y { + rgba.Set(x+i, y+j, color.RGBA{255, 0, 0, 255}) + } + } + } + } + + // 保存图像 + outFile, err := os.Create(params.OutputPath) + if err != nil { + return fmt.Errorf("无法创建输出文件: %w", err) + } + defer outFile.Close() + + // 编码为PNG并保存 + if err := png.Encode(outFile, rgba); err != nil { + return fmt.Errorf("无法编码和保存图像: %w", err) + } + + return nil +}