From 937b48236eca83b1ebf5efff6f94dc93509659ff Mon Sep 17 00:00:00 2001 From: debugtalk Date: Wed, 21 Sep 2022 16:03:58 +0800 Subject: [PATCH] change: make ocr as optional build tags --- hrp/internal/uixt/ext.go | 35 ++++++++++++------- hrp/internal/uixt/init.go | 2 +- hrp/internal/uixt/ocr_off.go | 10 ++++++ hrp/internal/uixt/{ocr.go => ocr_on.go} | 6 ++-- hrp/internal/uixt/ocr_test.go | 2 ++ .../uixt/{default.go => opencv_off.go} | 6 ++-- hrp/internal/uixt/{opencv.go => opencv_on.go} | 33 +++++++++++------ scripts/build.sh | 4 +-- 8 files changed, 65 insertions(+), 33 deletions(-) create mode 100644 hrp/internal/uixt/ocr_off.go rename hrp/internal/uixt/{ocr.go => ocr_on.go} (98%) rename hrp/internal/uixt/{default.go => opencv_off.go} (78%) rename hrp/internal/uixt/{opencv.go => opencv_on.go} (88%) diff --git a/hrp/internal/uixt/ext.go b/hrp/internal/uixt/ext.go index 02462183..40907c4f 100644 --- a/hrp/internal/uixt/ext.go +++ b/hrp/internal/uixt/ext.go @@ -21,24 +21,36 @@ import ( // TemplateMatchMode is the type of the template matching operation. type TemplateMatchMode int +type CVArgs struct { + scale float64 + matchMode TemplateMatchMode + threshold float64 +} + +type CVOption func(*CVArgs) + +func WithTemplateMatchMode(mode TemplateMatchMode) CVOption { + return func(args *CVArgs) { + args.matchMode = mode + } +} + +func WithThreshold(threshold float64) CVOption { + return func(args *CVArgs) { + args.threshold = threshold + } +} + type DriverExt struct { gwda.WebDriver windowSize gwda.Size frame *bytes.Buffer doneMjpegStream chan bool - // OpenCV - scale float64 - matchMode TemplateMatchMode - threshold float64 + CVArgs } -// Extend 获得扩展后的 Driver, -// 并指定匹配阀值, -// 获取当前设备的 Scale, -// 默认匹配模式为 TmCcoeffNormed, -// 默认关闭 OpenCV 匹配值计算后的输出 -func Extend(driver gwda.WebDriver, threshold float64, matchMode ...TemplateMatchMode) (dExt *DriverExt, err error) { +func extend(driver gwda.WebDriver) (dExt *DriverExt, err error) { dExt = &DriverExt{WebDriver: driver} dExt.doneMjpegStream = make(chan bool, 1) @@ -48,8 +60,7 @@ func Extend(driver gwda.WebDriver, threshold float64, matchMode ...TemplateMatch return nil, errors.Wrap(err, "failed to get windows size") } - err = dExt.extendOpenCV(threshold, matchMode...) - return dExt, err + return dExt, nil } func (dExt *DriverExt) ConnectMjpegStream(httpClient *http.Client) (err error) { diff --git a/hrp/internal/uixt/init.go b/hrp/internal/uixt/init.go index 5acd9bb7..7f45db42 100644 --- a/hrp/internal/uixt/init.go +++ b/hrp/internal/uixt/init.go @@ -45,7 +45,7 @@ func InitWDAClient(options ...gwda.DeviceOption) (*DriverExt, error) { if err != nil { return nil, errors.Wrap(err, "failed to init WDA driver") } - driverExt, err := Extend(driver, 0.95) + driverExt, err := Extend(driver) if err != nil { return nil, errors.Wrap(err, "failed to extend gwda.WebDriver") } diff --git a/hrp/internal/uixt/ocr_off.go b/hrp/internal/uixt/ocr_off.go new file mode 100644 index 00000000..7c3536b5 --- /dev/null +++ b/hrp/internal/uixt/ocr_off.go @@ -0,0 +1,10 @@ +//go:build !ocr + +package uixt + +import "github.com/rs/zerolog/log" + +func (dExt *DriverExt) FindTextByOCR(ocrText string) (x, y, width, height float64, err error) { + log.Fatal().Msg("OCR is not supported") + return +} diff --git a/hrp/internal/uixt/ocr.go b/hrp/internal/uixt/ocr_on.go similarity index 98% rename from hrp/internal/uixt/ocr.go rename to hrp/internal/uixt/ocr_on.go index d261e3de..2edf0d5a 100644 --- a/hrp/internal/uixt/ocr.go +++ b/hrp/internal/uixt/ocr_on.go @@ -1,16 +1,14 @@ +//go:build ocr + package uixt import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "image" - "io/ioutil" "mime/multipart" "net/http" "strings" - "time" ) var client = &http.Client{ diff --git a/hrp/internal/uixt/ocr_test.go b/hrp/internal/uixt/ocr_test.go index c334a1dd..928b39a2 100644 --- a/hrp/internal/uixt/ocr_test.go +++ b/hrp/internal/uixt/ocr_test.go @@ -1,3 +1,5 @@ +//go:build ocr + package uixt import ( diff --git a/hrp/internal/uixt/default.go b/hrp/internal/uixt/opencv_off.go similarity index 78% rename from hrp/internal/uixt/default.go rename to hrp/internal/uixt/opencv_off.go index 7506f686..7d56d5c5 100644 --- a/hrp/internal/uixt/default.go +++ b/hrp/internal/uixt/opencv_off.go @@ -5,12 +5,12 @@ package uixt import ( "image" + "github.com/electricbubble/gwda" "github.com/rs/zerolog/log" ) -func (dExt *DriverExt) extendOpenCV(threshold float64, matchMode ...TemplateMatchMode) (err error) { - log.Fatal().Msg("opencv is not supported") - return +func Extend(driver gwda.WebDriver, options ...CVOption) (dExt *DriverExt, err error) { + return extend(driver) } func (dExt *DriverExt) FindAllImageRect(search string) (rects []image.Rectangle, err error) { diff --git a/hrp/internal/uixt/opencv.go b/hrp/internal/uixt/opencv_on.go similarity index 88% rename from hrp/internal/uixt/opencv.go rename to hrp/internal/uixt/opencv_on.go index 504e6eef..6a184695 100644 --- a/hrp/internal/uixt/opencv.go +++ b/hrp/internal/uixt/opencv_on.go @@ -8,12 +8,15 @@ import ( "io/ioutil" "os" + "github.com/electricbubble/gwda" cvHelper "github.com/electricbubble/opencv-helper" ) const ( + // TmCcoeffNormed maps to TM_CCOEFF_NORMED + TmCcoeffNormed TemplateMatchMode = iota // TmSqdiff maps to TM_SQDIFF - TmSqdiff TemplateMatchMode = iota + TmSqdiff // TmSqdiffNormed maps to TM_SQDIFF_NORMED TmSqdiffNormed // TmCcorr maps to TM_CCORR @@ -22,8 +25,6 @@ const ( TmCcorrNormed // TmCcoeff maps to TM_CCOEFF TmCcoeff - // TmCcoeffNormed maps to TM_CCOEFF_NORMED - TmCcoeffNormed ) type DebugMode int @@ -42,18 +43,28 @@ const ( // 获取当前设备的 Scale, // 默认匹配模式为 TmCcoeffNormed, // 默认关闭 OpenCV 匹配值计算后的输出 -func (dExt *DriverExt) extendOpenCV(threshold float64, matchMode ...TemplateMatchMode) (err error) { - if dExt.scale, err = dExt.Scale(); err != nil { - return err +func Extend(driver gwda.WebDriver, options ...CVOption) (dExt *DriverExt, err error) { + dExt, err = extend(driver) + if err != nil { + return nil, err } - if len(matchMode) == 0 { - matchMode = []TemplateMatchMode{TmCcoeffNormed} + for _, option := range options { + option(&dExt.CVArgs) + } + + if dExt.scale, err = dExt.Scale(); err != nil { + return nil, err + } + + if dExt.threshold == 0 { + dExt.threshold = 0.95 // default threshold + } + if dExt.matchMode == 0 { + dExt.matchMode = TmCcoeffNormed // default match mode } - dExt.matchMode = matchMode[0] cvHelper.Debug(cvHelper.DebugMode(DmOff)) - dExt.threshold = threshold - return nil + return } func (dExt *DriverExt) Debug(dm DebugMode) { diff --git a/scripts/build.sh b/scripts/build.sh index 7a75d8ed..1c6cfd52 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -15,8 +15,8 @@ mkdir -p "output" bin_path="output/hrp" # build -# optional build tags: opencv -go build -ldflags '-s -w' -o "$bin_path" hrp/cmd/cli/main.go +# optional build tags: opencv ocr +go build -ldflags '-s -w' -tags ocr -o "$bin_path" hrp/cmd/cli/main.go # check output and version ls -lh "$bin_path"