Skip to content

Commit

Permalink
Create errorAssertionFunction creators
Browse files Browse the repository at this point in the history
To clean up table driven test with often repeated functions we can add a
simple function creator that generates a errorAssertionFunction
  • Loading branch information
JERHAV committed Mar 7, 2024
1 parent bb548d0 commit 1b432fb
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
11 changes: 11 additions & 0 deletions assert/assertions.go
Expand Up @@ -48,6 +48,17 @@ type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool
// Comparison is a custom function that returns true on success and false on failure
type Comparison func() (success bool)

// ErrorIsFor returns an [ErrorAssertionFunc] which tests if the error wraps target.
func ErrorIsFor(target error) ErrorAssertionFunc {
return func(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}

return ErrorIs(t, err, target, msgAndArgs...)
}
}

/*
Helper functions
*/
Expand Down
22 changes: 20 additions & 2 deletions assert/assertions_test.go
Expand Up @@ -2750,7 +2750,13 @@ func ExampleErrorAssertionFunc() {
t := &testing.T{} // provided by test

dumbParseNum := func(input string, v interface{}) error {
return json.Unmarshal([]byte(input), v)

err := json.Unmarshal([]byte(input), v)
if err != nil {
return testingError{"could not Unmarshal " + input}
}

return nil
}

tests := []struct {
Expand All @@ -2760,8 +2766,8 @@ func ExampleErrorAssertionFunc() {
}{
{"1.2 is number", "1.2", NoError},
{"1.2.3 not number", "1.2.3", Error},
{"true is not number", "true", Error},
{"3 is number", "3", NoError},
{"3% is not a valid number", "3%", ErrorIsFor(testingError{"could not Unmarshal 3%"})},
}

for _, tt := range tests {
Expand All @@ -2772,14 +2778,26 @@ func ExampleErrorAssertionFunc() {
}
}

type testingError struct {
extraInfo string
}

func (t testingError) Error() string {
return t.extraInfo
}

func TestErrorAssertionFunc(t *testing.T) {
var testError = errors.New("test error")
tests := []struct {
name string
err error
assertion ErrorAssertionFunc
}{
{"noError", nil, NoError},
{"error", errors.New("whoops"), Error},
{"errorIs", testError, ErrorIsFor(testError)},
{"wrappedErrorIs", fmt.Errorf("This wrapped error: %w", testError),
ErrorIsFor(testError)},
}

for _, tt := range tests {
Expand Down
11 changes: 11 additions & 0 deletions require/requirements.go
Expand Up @@ -26,4 +26,15 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{})
// for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{})

// ErrorIsFunc returns an [ErrorAssertionFunc] which tests if the error wraps target.
func ErrorIsFor(expectedError error) ErrorAssertionFunc {
return func(t TestingT, err error, msgsAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}

ErrorIs(t, err, expectedError, msgsAndArgs...)
}
}

//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=require -template=require.go.tmpl -include-format-funcs"
5 changes: 5 additions & 0 deletions require/requirements_test.go
Expand Up @@ -3,6 +3,7 @@ package require
import (
"encoding/json"
"errors"
"fmt"
"testing"
"time"
)
Expand Down Expand Up @@ -666,13 +667,17 @@ func ExampleErrorAssertionFunc() {
}

func TestErrorAssertionFunc(t *testing.T) {
var testError = errors.New("test error")
tests := []struct {
name string
err error
assertion ErrorAssertionFunc
}{
{"noError", nil, NoError},
{"error", errors.New("whoops"), Error},
{"errorIs", testError, ErrorIsFor(testError)},
{"wrappedErrorIs", fmt.Errorf("This wrapped error: %w", testError),
ErrorIsFor(testError)},
}

for _, tt := range tests {
Expand Down

0 comments on commit 1b432fb

Please sign in to comment.