Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.

Commit f9b4ad1

Browse files
authoredFeb 28, 2020
Fix #71 Do + DoAndReturn signature change error msg (#395)
Update the error handling for Call.Do and Call.DoAndReturn in the case where the argument passed does not match expectations. * panic if the argument is not a function * panic if the number of input arguments do not match those expected by Call * panic if the types of the input arguments do not match those expected by Call Call.DoAndReturn has additional validations on the return signature * panic if the number of return arguments do not match those expected by Call * panic if the types of return arguments do not match those expected by Call
1 parent b48cb66 commit f9b4ad1

File tree

4 files changed

+1109
-3
lines changed

4 files changed

+1109
-3
lines changed
 

‎gomock/call.go

+32-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
"reflect"
2020
"strconv"
2121
"strings"
22+
23+
"github.com/golang/mock/gomock/internal/validate"
2224
)
2325

2426
// Call represents an expected call to a mock.
@@ -105,10 +107,24 @@ func (c *Call) MaxTimes(n int) *Call {
105107
// DoAndReturn declares the action to run when the call is matched.
106108
// The return values from this function are returned by the mocked function.
107109
// It takes an interface{} argument to support n-arity functions.
110+
// If the method signature of f is not compatible with the mocked function a
111+
// panic will be triggered. Both the arguments and return values of f are
112+
// validated.
108113
func (c *Call) DoAndReturn(f interface{}) *Call {
109-
// TODO: Check arity and types here, rather than dying badly elsewhere.
110114
v := reflect.ValueOf(f)
111115

116+
switch v.Kind() {
117+
case reflect.Func:
118+
mt := c.methodType
119+
120+
ft := v.Type()
121+
if err := validate.InputAndOutputSig(ft, mt); err != nil {
122+
panic(fmt.Sprintf("DoAndReturn: %s", err))
123+
}
124+
default:
125+
panic("DoAndReturn: argument must be a function")
126+
}
127+
112128
c.addAction(func(args []interface{}) []interface{} {
113129
vargs := make([]reflect.Value, len(args))
114130
ft := v.Type()
@@ -134,10 +150,24 @@ func (c *Call) DoAndReturn(f interface{}) *Call {
134150
// return values are ignored to retain backward compatibility. To use the
135151
// return values call DoAndReturn.
136152
// It takes an interface{} argument to support n-arity functions.
153+
// If the method signature of f is not compatible with the mocked function a
154+
// panic will be triggered. Only the arguments of f are validated; not the return
155+
// values.
137156
func (c *Call) Do(f interface{}) *Call {
138-
// TODO: Check arity and types here, rather than dying badly elsewhere.
139157
v := reflect.ValueOf(f)
140158

159+
switch v.Kind() {
160+
case reflect.Func:
161+
mt := c.methodType
162+
163+
ft := v.Type()
164+
if err := validate.InputSig(ft, mt); err != nil {
165+
panic(fmt.Sprintf("Do: %s", err))
166+
}
167+
default:
168+
panic("Do: argument must be a function")
169+
}
170+
141171
c.addAction(func(args []interface{}) []interface{} {
142172
vargs := make([]reflect.Value, len(args))
143173
ft := v.Type()

‎gomock/call_test.go

+875
Large diffs are not rendered by default.

‎gomock/controller_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -492,9 +492,11 @@ func TestDo(t *testing.T) {
492492
doCalled := false
493493
var argument string
494494
ctrl.RecordCall(subject, "FooMethod", "argument").Do(
495-
func(arg string) {
495+
func(arg string) int {
496496
doCalled = true
497497
argument = arg
498+
499+
return 0
498500
})
499501
if doCalled {
500502
t.Error("Do() callback called too early.")

‎gomock/internal/validate/validate.go

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
// Copyright 2020 Google Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package validate
16+
17+
import (
18+
"fmt"
19+
"reflect"
20+
)
21+
22+
// InputAndOutputSig compares the argument and return signatures of actualFunc
23+
// against expectedFunc. It returns an error unless everything matches.
24+
func InputAndOutputSig(actualFunc, expectedFunc reflect.Type) error {
25+
if err := InputSig(actualFunc, expectedFunc); err != nil {
26+
return err
27+
}
28+
29+
if err := outputSig(actualFunc, expectedFunc); err != nil {
30+
return err
31+
}
32+
33+
return nil
34+
}
35+
36+
// InputSig compares the argument signatures of actualFunc
37+
// against expectedFunc. It returns an error unless everything matches.
38+
func InputSig(actualFunc, expectedFunc reflect.Type) error {
39+
// check number of arguments and type of each argument
40+
if actualFunc.NumIn() != expectedFunc.NumIn() {
41+
return fmt.Errorf(
42+
"expected function to have %d arguments not %d",
43+
expectedFunc.NumIn(), actualFunc.NumIn())
44+
}
45+
46+
lastIdx := expectedFunc.NumIn()
47+
48+
// If the function has a variadic argument validate that one first so that
49+
// we aren't checking for it while we iterate over the other args
50+
if expectedFunc.IsVariadic() {
51+
if ok := variadicArg(lastIdx, actualFunc, expectedFunc); !ok {
52+
i := lastIdx - 1
53+
return fmt.Errorf(
54+
"expected function to have"+
55+
" arg of type %v at position %d"+
56+
" not type %v",
57+
expectedFunc.In(i), i, actualFunc.In(i),
58+
)
59+
}
60+
61+
lastIdx--
62+
}
63+
64+
for i := 0; i < lastIdx; i++ {
65+
expectedArg := expectedFunc.In(i)
66+
actualArg := actualFunc.In(i)
67+
68+
if err := arg(actualArg, expectedArg); err != nil {
69+
return fmt.Errorf("input argument at %d: %s", i, err)
70+
}
71+
}
72+
73+
return nil
74+
}
75+
76+
func outputSig(actualFunc, expectedFunc reflect.Type) error {
77+
// check number of return vals and type of each val
78+
if actualFunc.NumOut() != expectedFunc.NumOut() {
79+
return fmt.Errorf(
80+
"expected function to have %d return vals not %d",
81+
expectedFunc.NumOut(), actualFunc.NumOut())
82+
}
83+
84+
for i := 0; i < expectedFunc.NumOut(); i++ {
85+
expectedArg := expectedFunc.Out(i)
86+
actualArg := actualFunc.Out(i)
87+
88+
if err := arg(actualArg, expectedArg); err != nil {
89+
return fmt.Errorf("return argument at %d: %s", i, err)
90+
}
91+
}
92+
93+
return nil
94+
}
95+
96+
func variadicArg(lastIdx int, actualFunc, expectedFunc reflect.Type) bool {
97+
if actualFunc.In(lastIdx-1) != expectedFunc.In(lastIdx-1) {
98+
if actualFunc.In(lastIdx-1).Kind() != reflect.Slice {
99+
return false
100+
}
101+
102+
expectedArgT := expectedFunc.In(lastIdx - 1)
103+
expectedElem := expectedArgT.Elem()
104+
if expectedElem.Kind() != reflect.Interface {
105+
return false
106+
}
107+
108+
actualArgT := actualFunc.In(lastIdx - 1)
109+
actualElem := actualArgT.Elem()
110+
111+
if ok := actualElem.ConvertibleTo(expectedElem); !ok {
112+
return false
113+
}
114+
115+
}
116+
117+
return true
118+
}
119+
120+
func interfaceArg(actualArg, expectedArg reflect.Type) error {
121+
if !actualArg.ConvertibleTo(expectedArg) {
122+
return fmt.Errorf(
123+
"expected arg convertible to type %v not type %v",
124+
expectedArg, actualArg,
125+
)
126+
}
127+
128+
return nil
129+
}
130+
131+
func mapArg(actualArg, expectedArg reflect.Type) error {
132+
expectedKey := expectedArg.Key()
133+
actualKey := actualArg.Key()
134+
135+
switch expectedKey.Kind() {
136+
case reflect.Interface:
137+
if err := interfaceArg(actualKey, expectedKey); err != nil {
138+
return fmt.Errorf("map key: %s", err)
139+
}
140+
default:
141+
if actualKey != expectedKey {
142+
return fmt.Errorf("expected map key of type %v not type %v",
143+
expectedKey, actualKey)
144+
}
145+
}
146+
147+
expectedElem := expectedArg.Elem()
148+
actualElem := actualArg.Elem()
149+
150+
switch expectedElem.Kind() {
151+
case reflect.Interface:
152+
if err := interfaceArg(actualElem, expectedElem); err != nil {
153+
return fmt.Errorf("map element: %s", err)
154+
}
155+
default:
156+
if actualElem != expectedElem {
157+
return fmt.Errorf("expected map element of type %v not type %v",
158+
expectedElem, actualElem)
159+
}
160+
}
161+
162+
return nil
163+
}
164+
165+
func arg(actualArg, expectedArg reflect.Type) error {
166+
switch expectedArg.Kind() {
167+
// If the expected arg is an interface we only care if the actual arg is convertible
168+
// to that interface
169+
case reflect.Interface:
170+
if err := interfaceArg(actualArg, expectedArg); err != nil {
171+
return err
172+
}
173+
default:
174+
// If the expected arg is not an interface then first check to see if
175+
// the actual arg is even the same reflect.Kind
176+
if expectedArg.Kind() != actualArg.Kind() {
177+
return fmt.Errorf("expected arg of kind %v not %v",
178+
expectedArg.Kind(), actualArg.Kind())
179+
}
180+
181+
switch expectedArg.Kind() {
182+
// If the expected arg is a map then we need to handle the case where
183+
// the map key or element type is an interface
184+
case reflect.Map:
185+
if err := mapArg(actualArg, expectedArg); err != nil {
186+
return err
187+
}
188+
default:
189+
if actualArg != expectedArg {
190+
return fmt.Errorf(
191+
"Expected arg of type %v not type %v",
192+
expectedArg, actualArg,
193+
)
194+
}
195+
}
196+
}
197+
198+
return nil
199+
}

0 commit comments

Comments
 (0)
This repository has been archived.