Skip to content

Commit

Permalink
Merge pull request #629 from dillonstreator/expecter-structs-with-rol…
Browse files Browse the repository at this point in the history
…led-variadics

Add support for rolled variadics with expecter structs
  • Loading branch information
LandonTClipp committed May 23, 2023
2 parents b964d01 + ae9feff commit d0c93f6
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 33 deletions.
5 changes: 5 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,9 @@ packages:
unroll-variadic: False
- mockname: RequesterVariadic
Expecter:
config:
with-expecter: True
configs:
- mockname: ExpecterAndRolledVariadic
unroll-variadic: False
RequesterReturnElided:

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 9 additions & 5 deletions pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -696,11 +696,6 @@ func (g *Generator) Generate(ctx context.Context) error {
}

for _, method := range g.iface.Methods() {
// It's probably possible, but not worth the trouble for prototype
if method.Signature.Variadic() && g.config.WithExpecter && !g.config.UnrollVariadic {
return fmt.Errorf("cannot generate a valid expecter for variadic method with unroll-variadic=false")
}

g.generateMethod(ctx, method)
}

Expand Down Expand Up @@ -938,6 +933,15 @@ func {{ .ConstructorName }}{{ .TypeConstraint }}(t {{ .ConstructorTestingInterfa
func (g *Generator) generateCalled(list *paramList) (preamble string, called string) {
namesLen := len(list.Names)
if namesLen == 0 || !list.Variadic || !g.config.UnrollVariadic {
if list.Variadic && !g.config.UnrollVariadic && g.config.WithExpecter {
variadicName := list.Names[namesLen-1]
tmpRet := resolveCollision(list.Names, "tmpRet")

preamble = fmt.Sprintf("\n\tvar " + tmpRet + " mock.Arguments\n\tif len(" + variadicName + ") > 0 {\n\t\t" + tmpRet + " = _m.Called(" + strings.Join(list.Names, ", ") + ")\n\t} else {\n\t\t" + tmpRet + " = _m.Called(" + strings.Join(list.Names[:len(list.Names)-1], ", ") + ")\n\t}\n\n\t")
called = tmpRet
return
}

called = "_m.Called(" + strings.Join(list.Names, ", ") + ")"
return
}
Expand Down

0 comments on commit d0c93f6

Please sign in to comment.