refactor: call function

This commit is contained in:
debugtalk
2022-01-06 18:01:23 +08:00
parent b21cb780c5
commit 3f1f9d3690
4 changed files with 142 additions and 139 deletions

125
parser.go
View File

@@ -9,14 +9,11 @@ import (
"plugin" "plugin"
"reflect" "reflect"
"regexp" "regexp"
"runtime"
"strings" "strings"
"github.com/maja42/goval" "github.com/maja42/goval"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/httprunner/hrp/internal/builtin"
) )
func newParser() *parser { func newParser() *parser {
@@ -28,40 +25,6 @@ type parser struct {
pluginLoader *plugin.Plugin pluginLoader *plugin.Plugin
} }
func (p *parser) loadPlugin(path string) error {
if runtime.GOOS == "windows" {
log.Warn().Msg("go plugin does not support windows")
return nil
}
if path == "" {
return nil
}
// check if loaded before
if p.pluginLoader != nil {
return nil
}
// locate plugin file
pluginPath, err := locatePlugin(path)
if err != nil {
// plugin not found
return nil
}
// load plugin
plugins, err := plugin.Open(pluginPath)
if err != nil {
log.Error().Err(err).Str("path", path).Msg("load go plugin failed")
return err
}
p.pluginLoader = plugins
log.Info().Str("path", path).Msg("load go plugin success")
return nil
}
// locatePlugin searches debugtalk.so upward recursively until current // locatePlugin searches debugtalk.so upward recursively until current
// working directory or system root dir. // working directory or system root dir.
func locatePlugin(startPath string) (string, error) { func locatePlugin(startPath string) (string, error) {
@@ -251,12 +214,7 @@ func (p *parser) parseString(raw string, variablesMapping map[string]interface{}
return raw, err return raw, err
} }
fn, err := getMappingFunction(funcName, p.pluginLoader) result, err := p.callFunc(funcName, parsedArgs.([]interface{})...)
if err != nil {
return raw, err
}
result, err := callFunc(fn, parsedArgs.([]interface{})...)
if err != nil { if err != nil {
log.Error().Str("funcName", funcName).Interface("arguments", arguments). log.Error().Str("funcName", funcName).Interface("arguments", arguments).
Err(err).Msg("call function failed") Err(err).Msg("call function failed")
@@ -342,87 +300,6 @@ func mergeVariables(variables, overriddenVariables map[string]interface{}) map[s
return mergedVariables return mergedVariables
} }
func getMappingFunction(funcName string, pluginLoader *plugin.Plugin) (reflect.Value, error) {
var fn reflect.Value
var err error
defer func() {
// check function type
if err == nil && fn.Kind() != reflect.Func {
// function not valid
err = fmt.Errorf("function %s is invalid", funcName)
return
}
}()
// get function from plugin loader
if pluginLoader != nil {
sym, err := pluginLoader.Lookup(funcName)
if err == nil {
fn = reflect.ValueOf(sym)
return fn, nil
}
}
// get builtin function
if function, ok := builtin.Functions[funcName]; ok {
fn = reflect.ValueOf(function)
return fn, nil
}
// function not found
return reflect.Value{}, fmt.Errorf("function %s is not found", funcName)
}
// callFunc call function with arguments
// only support return at most one result value
func callFunc(fn reflect.Value, arguments ...interface{}) (interface{}, error) {
if fn.Type().NumIn() != len(arguments) {
// function arguments not match
return nil, fmt.Errorf("function arguments number not match")
}
argumentsValue := make([]reflect.Value, len(arguments))
for index, argument := range arguments {
argumentValue := reflect.ValueOf(argument)
expectArgumentType := fn.Type().In(index)
actualArgumentType := reflect.TypeOf(argument)
// type match
if expectArgumentType == actualArgumentType {
argumentsValue[index] = argumentValue
continue
}
// type not match, check if convertible
if !actualArgumentType.ConvertibleTo(expectArgumentType) {
// function argument type not match and not convertible
err := fmt.Errorf("function argument %d's type is neither match nor convertible, expect %v, actual %v",
index, expectArgumentType, actualArgumentType)
return nil, err
}
// convert argument to expect type
argumentsValue[index] = argumentValue.Convert(expectArgumentType)
}
resultValues := fn.Call(argumentsValue)
if len(resultValues) > 1 {
// function should return at most one value
err := fmt.Errorf("function should return at most one value")
return nil, err
}
// no return value
if len(resultValues) == 0 {
return nil, nil
}
// return one value
// convert reflect.Value to interface{}
result := resultValues[0].Interface()
return result, nil
}
var eval = goval.NewEvaluator() var eval = goval.NewEvaluator()
// literalEval parse string to number if possible // literalEval parse string to number if possible

View File

@@ -378,17 +378,17 @@ func TestMergeVariables(t *testing.T) {
} }
func TestCallBuiltinFunction(t *testing.T) { func TestCallBuiltinFunction(t *testing.T) {
parser := newParser()
// call function without arguments // call function without arguments
f1, _ := getMappingFunction("get_timestamp", nil) _, err := parser.callFunc("get_timestamp")
_, err := callFunc(f1)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.Fail() t.Fail()
} }
// call function with one argument // call function with one argument
timeStart := time.Now() timeStart := time.Now()
f2, _ := getMappingFunction("sleep", nil) _, err = parser.callFunc("sleep", 1)
_, err = callFunc(f2, 1)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.Fail() t.Fail()
} }
@@ -397,8 +397,7 @@ func TestCallBuiltinFunction(t *testing.T) {
} }
// call function with one argument // call function with one argument
f3, _ := getMappingFunction("gen_random_string", nil) result, err := parser.callFunc("gen_random_string", 10)
result, err := callFunc(f3, 10)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.Fail() t.Fail()
} }
@@ -407,8 +406,7 @@ func TestCallBuiltinFunction(t *testing.T) {
} }
// call function with two argument // call function with two argument
f4, _ := getMappingFunction("max", nil) result, err = parser.callFunc("max", float64(10), 9.99)
result, err = callFunc(f4, float64(10), 9.99)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.Fail() t.Fail()
} }

132
plugin.go Normal file
View File

@@ -0,0 +1,132 @@
package hrp
import (
"fmt"
"plugin"
"reflect"
"runtime"
"github.com/rs/zerolog/log"
"github.com/httprunner/hrp/internal/builtin"
)
func (p *parser) loadPlugin(path string) error {
if runtime.GOOS == "windows" {
log.Warn().Msg("go plugin does not support windows")
return nil
}
if path == "" {
return nil
}
// check if loaded before
if p.pluginLoader != nil {
return nil
}
// locate plugin file
pluginPath, err := locatePlugin(path)
if err != nil {
// plugin not found
return nil
}
// load plugin
plugins, err := plugin.Open(pluginPath)
if err != nil {
log.Error().Err(err).Str("path", path).Msg("load go plugin failed")
return err
}
p.pluginLoader = plugins
log.Info().Str("path", path).Msg("load go plugin success")
return nil
}
func getMappingFunction(funcName string, pluginLoader *plugin.Plugin) (reflect.Value, error) {
var fn reflect.Value
var err error
defer func() {
// check function type
if err == nil && fn.Kind() != reflect.Func {
// function not valid
err = fmt.Errorf("function %s is invalid", funcName)
return
}
}()
// get function from plugin loader
if pluginLoader != nil {
sym, err := pluginLoader.Lookup(funcName)
if err == nil {
fn = reflect.ValueOf(sym)
return fn, nil
}
}
// get builtin function
if function, ok := builtin.Functions[funcName]; ok {
fn = reflect.ValueOf(function)
return fn, nil
}
// function not found
return reflect.Value{}, fmt.Errorf("function %s is not found", funcName)
}
// callFunc calls function with arguments
// only support return at most one result value
func (p *parser) callFunc(funcName string, arguments ...interface{}) (interface{}, error) {
fn, err := getMappingFunction(funcName, p.pluginLoader)
if err != nil {
return nil, err
}
if fn.Type().NumIn() != len(arguments) {
// function arguments not match
return nil, fmt.Errorf("function arguments number not match")
}
argumentsValue := make([]reflect.Value, len(arguments))
for index, argument := range arguments {
argumentValue := reflect.ValueOf(argument)
expectArgumentType := fn.Type().In(index)
actualArgumentType := reflect.TypeOf(argument)
// type match
if expectArgumentType == actualArgumentType {
argumentsValue[index] = argumentValue
continue
}
// type not match, check if convertible
if !actualArgumentType.ConvertibleTo(expectArgumentType) {
// function argument type not match and not convertible
err := fmt.Errorf("function argument %d's type is neither match nor convertible, expect %v, actual %v",
index, expectArgumentType, actualArgumentType)
return nil, err
}
// convert argument to expect type
argumentsValue[index] = argumentValue.Convert(expectArgumentType)
}
resultValues := fn.Call(argumentsValue)
if len(resultValues) > 1 {
// function should return at most one value
err := fmt.Errorf("function should return at most one value")
return nil, err
}
// no return value
if len(resultValues) == 0 {
return nil, nil
}
// return one value
// convert reflect.Value to interface{}
result := resultValues[0].Interface()
return result, nil
}

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"plugin"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -27,14 +26,11 @@ func TestMain(m *testing.M) {
} }
func TestCallPluginFunction(t *testing.T) { func TestCallPluginFunction(t *testing.T) {
pluginLoader, err := plugin.Open("examples/debugtalk.so") parser := newParser()
if err != nil { parser.loadPlugin("examples/debugtalk.so")
t.Fatalf(err.Error())
}
// call function without arguments // call function without arguments
f1, _ := getMappingFunction("Concatenate", pluginLoader) result, err := parser.callFunc("Concatenate", 1, "2", 3.14)
result, err := callFunc(f1, 1, "2", 3.14)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.Fail() t.Fail()
} }