Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mock: add support for mock.Anything in slices #1577

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
167 changes: 93 additions & 74 deletions mock/mock.go
Expand Up @@ -930,99 +930,118 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
}

for i := 0; i < maxArgCount; i++ {
var actual, expected interface{}
var actualFmt, expectedFmt string
var expected, actual interface{}
if len(args) <= i {
expected = "(Missing)"
} else {
expected = args[i]
}

if len(objects) <= i {
actual = "(Missing)"
actualFmt = "(Missing)"
} else {
actual = objects[i]
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
}

if len(args) <= i {
expected = "(Missing)"
expectedFmt = "(Missing)"
} else {
expected = args[i]
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
equal, elementOutput := compareElements(expected, actual, i, false)
output += elementOutput
if !equal {
differences++
}
}

if matcher, ok := expected.(argumentMatcher); ok {
var matches bool
func() {
defer func() {
if r := recover(); r != nil {
actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
}
}()
matches = matcher.Matches(actual)
if differences == 0 {
return "No differences.", differences
}

return output, differences
}

func compareElements(expected, actual interface{}, i int, isSlice bool) (bool, string) {
var expectedFmt, actualFmt string
if expected == "(Missing)" {
expectedFmt = expected.(string)
} else {
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
}
if actual == "(Missing)" {
actualFmt = actual.(string)
} else {
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
}
if matcher, ok := expected.(argumentMatcher); ok {
var matches bool
func() {
defer func() {
if r := recover(); r != nil {
actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
}
}()
if matches {
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
matches = matcher.Matches(actual)
}()
if matches {
return true, fmt.Sprintf("\t%d: PASS: %s matched by %s\n", i, actualFmt, matcher)
} else {
return false, fmt.Sprintf("\t%d: FAIL: %s not matched by %s\n", i, actualFmt, matcher)
}
} else {
switch expected := expected.(type) {
case anythingOfTypeArgument:
if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected, reflect.TypeOf(actual).Name(), actualFmt)
} else {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
return true, ""
}
} else {
switch expected := expected.(type) {
case anythingOfTypeArgument:
// type checking
if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
// not match
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
}
case *IsTypeArgument:
actualT := reflect.TypeOf(actual)
if actualT != expected.t {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt)
}
case *FunctionalOptionsArgument:
t := expected.value
case *IsTypeArgument:
actualT := reflect.TypeOf(actual)
if actualT != expected.t {
return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected.t, reflect.TypeOf(actual).Name(), actualFmt)
} else {
return true, ""
}
case *FunctionalOptionsArgument:
t := expected.value

var name string
tValue := reflect.ValueOf(t)
if tValue.Len() > 0 {
name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String()
}
var name string
tValue := reflect.ValueOf(t)
if tValue.Len() > 0 {
name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String()
}

tName := reflect.TypeOf(t).Name()
if name != reflect.TypeOf(actual).String() && tValue.Len() != 0 {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt)
tName := reflect.TypeOf(t).Name()
if name != reflect.TypeOf(actual).String() && tValue.Len() != 0 {
return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, tName, reflect.TypeOf(actual).Name(), actualFmt)
} else {
if ef, af := assertOpts(t, actual); ef == "" && af == "" {
return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, tName, tName)
} else {
if ef, af := assertOpts(t, actual); ef == "" && af == "" {
// match
output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName)
} else {
// not match
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef)
}
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, af, ef)
}

default:
if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
// match
output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
} else {
// not match
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt)
}
case []interface{}:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the new switch statement added for supporting slices.

Maybe I should declare a type other than []interface{} here? Not sure. Open to any feedback.

My thought is that since we don't check for slices at all presently, a "catch-all" for slices with []interface{} which then calls the existing (refactored) compareElements logic makes sense.

if ev, av := reflect.ValueOf(expected), reflect.ValueOf(actual); ev.Kind() == reflect.Slice && av.Kind() == reflect.Slice {
// Unroll slices to check for Anything / AnythingOFType
if ev.Len() != av.Len() {
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt)
}
for e := 0; e < ev.Len(); e++ {
equal, _ := compareElements(ev.Index(e).Interface(), av.Index(e).Interface(), i, true)
if !equal {
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt)
}
}
return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt)
}
default:
if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
// match
return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt)
} else {
// not match
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt)
}
}

}

if differences == 0 {
return "No differences.", differences
}

return output, differences
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt)
}

// Assert compares the arguments with the specified objects and fails if
Expand Down
30 changes: 30 additions & 0 deletions mock/mock_test.go
Expand Up @@ -2146,3 +2146,33 @@ type user interface {
type mockUser struct{ Mock }

func (m *mockUser) Use(c caller) { m.Called(c) }

func TestAnythingInSlices(t *testing.T) {
m := &TestExampleImplementation{}

m.On("TheExampleMethodVariadic", []interface{}{1, Anything, 3, Anything, 5}).Return(nil)
var err error

assert.NotPanics(t, func() {
err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5)
})

assert.NoError(t, err)
m.AssertExpectations(t)
m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{Anything, 2, Anything, 4, Anything})
}

func TestAnythingOfTypeInSlices(t *testing.T) {
m := &TestExampleImplementation{}

m.On("TheExampleMethodVariadic", []interface{}{1, AnythingOfType("int"), 3, AnythingOfType("int"), 5}).Return(nil)
var err error

assert.NotPanics(t, func() {
err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5)
})

assert.NoError(t, err)
m.AssertExpectations(t)
m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{AnythingOfType("int"), 2, AnythingOfType("int"), 4, AnythingOfType("int")})
}