diff --git a/builtin/assertion.go b/builtin/assertion.go index 99782626..ca290aa8 100644 --- a/builtin/assertion.go +++ b/builtin/assertion.go @@ -1,6 +1,9 @@ package builtin import ( + "fmt" + "strings" + "github.com/stretchr/testify/assert" ) @@ -15,10 +18,36 @@ var Assertions = map[string]func(t assert.TestingT, expected interface{}, actual "contains": assert.Contains, "regex_match": assert.Regexp, // custom assertions + "startswith": StartsWith, // check if string starts with substring + "endswith": EndsWith, // check if string ends with substring "length_equals": EqualLength, "length_equal": EqualLength, // alias for length_equals } +func StartsWith(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if !assert.IsType(t, "string", actual, fmt.Sprintf("actual is %v", actual)) { + return false + } + if !assert.IsType(t, "string", expected, fmt.Sprintf("expected is %v", expected)) { + return false + } + actualString := actual.(string) + expectedString := expected.(string) + return assert.True(t, strings.HasPrefix(actualString, expectedString), msgAndArgs...) +} + +func EndsWith(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if !assert.IsType(t, "string", actual, fmt.Sprintf("actual is %v", actual)) { + return false + } + if !assert.IsType(t, "string", expected, fmt.Sprintf("expected is %v", expected)) { + return false + } + actualString := actual.(string) + expectedString := expected.(string) + return assert.True(t, strings.HasSuffix(actualString, expectedString), msgAndArgs...) +} + func EqualLength(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { return assert.Len(t, actual, expected.(int), msgAndArgs...) } diff --git a/builtin/assertion_test.go b/builtin/assertion_test.go new file mode 100644 index 00000000..328839be --- /dev/null +++ b/builtin/assertion_test.go @@ -0,0 +1,43 @@ +package builtin + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStartsWith(t *testing.T) { + testData := []struct { + raw string + expected string + }{ + {"", ""}, + {"a", "a"}, + {"abc", "a"}, + {"abc", "ab"}, + } + + for _, data := range testData { + if !assert.True(t, StartsWith(t, data.expected, data.raw)) { + t.Fail() + } + } +} + +func TestEndsWith(t *testing.T) { + testData := []struct { + raw string + expected string + }{ + {"", ""}, + {"a", "a"}, + {"abc", "c"}, + {"abc", "bc"}, + } + + for _, data := range testData { + if !assert.True(t, EndsWith(t, data.expected, data.raw)) { + t.Fail() + } + } +} diff --git a/validate.go b/validate.go index 4244edb7..d325fe22 100644 --- a/validate.go +++ b/validate.go @@ -32,6 +32,28 @@ func (s *stepRequestValidation) AssertEqual(jmesPath string, expected interface{ return s } +func (s *stepRequestValidation) AssertStartsWith(jmesPath string, expected interface{}, msg string) *stepRequestValidation { + validator := TValidator{ + Check: jmesPath, + Assert: "startswith", + Expect: expected, + Message: msg, + } + s.step.Validators = append(s.step.Validators, validator) + return s +} + +func (s *stepRequestValidation) AssertEndsWith(jmesPath string, expected interface{}, msg string) *stepRequestValidation { + validator := TValidator{ + Check: jmesPath, + Assert: "endswith", + Expect: expected, + Message: msg, + } + s.step.Validators = append(s.step.Validators, validator) + return s +} + func (s *stepRequestValidation) AssertLengthEqual(jmesPath string, expected interface{}, msg string) *stepRequestValidation { validator := TValidator{ Check: jmesPath,