diff --git a/core_dsl.go b/core_dsl.go index 6e20042c9..4ea63b84e 100644 --- a/core_dsl.go +++ b/core_dsl.go @@ -724,10 +724,10 @@ DeferCleanup can be called within any Setup or Subject node to register a cleanu DeferCleanup can be passed: 1. A function that takes no arguments and returns no values. -2. A function that returns an error (in which case it will assert that the returned error was nil, or it will fail the spec). -3. A function that takes a context.Context or SpecContext (and optionally returns an error). The resulting cleanup node is deemed interruptible and the passed-in context will be cancelled in the event of a timeout or interrupt. -4. A function that takes arguments (and optionally returns an error) followed by a list of arguments to pass to the function. -5. A function that takes SpecContext and a list of arguments (and optionally returns an error) followed by a list of arguments to pass to the function. +2. A function that returns multiple values. `DeferCleanup` will ignore all these return values except for the last one. If this last return value is a non-nil error `DeferCleanup` will fail the spec). +3. A function that takes a context.Context or SpecContext (and optionally returns multiple values). The resulting cleanup node is deemed interruptible and the passed-in context will be cancelled in the event of a timeout or interrupt. +4. A function that takes arguments (and optionally returns multiple values) followed by a list of arguments to pass to the function. +5. A function that takes SpecContext and a list of arguments (and optionally returns multiple values) followed by a list of arguments to pass to the function. For example: diff --git a/docs/index.md b/docs/index.md index 07163cc97..534a6c912 100644 --- a/docs/index.md +++ b/docs/index.md @@ -982,7 +982,7 @@ As you can see, `DeferCleanup()` can be called inside any setup or subject nodes `DeferCleanup` has a few more tricks up its sleeve. -As shown above `DeferCleanup` can be passed a function that takes no arguments and returns no value. You can also pass a function that returns a single value. `DeferCleanup` interprets this value as an error and fails the spec if the error is non-nil - a common go pattern. This allows us to rewrite our example as: +As shown above `DeferCleanup` can be passed a function that takes no arguments and returns no value. You can also pass a function that returns values. `DeferCleanup` ignores all these return value except for the last. If the last return value is a non-nil error - a common go pattern - `DeferCleanup` will fail the spec. This allows us to rewrite our example as: ```go Describe("Reporting book weight", func() { diff --git a/internal/internal_integration/cleanup_test.go b/internal/internal_integration/cleanup_test.go index c74a2e728..6b738f35c 100644 --- a/internal/internal_integration/cleanup_test.go +++ b/internal/internal_integration/cleanup_test.go @@ -110,6 +110,26 @@ var _ = Describe("Cleanup", func() { }) }) + Context("because of a returned error, for a multi-return function", func() { + BeforeEach(func() { + success, _ := RunFixture("cleanup failure", func() { + BeforeEach(rt.T("BE", C("C-BE"))) + It("A", rt.T("A", func() { + DeferCleanup(func() (string, error) { + rt.Run("C-A") + return "ok", fmt.Errorf("fail") + }) + })) + }) + Ω(success).Should(BeFalse()) + }) + + It("reports a failure", func() { + Ω(rt).Should(HaveTracked("BE", "A", "C-A", "C-BE")) + Ω(reporter.Did.Find("A")).Should(HaveFailed("DeferCleanup callback returned error: fail", FailureNodeType(types.NodeTypeCleanupAfterEach), types.FailureNodeAtTopLevel)) + }) + }) + Context("at the suite level", func() { BeforeEach(func() { success, _ := RunFixture("cleanup failure", func() { diff --git a/internal/node.go b/internal/node.go index 4f3d1c985..9eb835e9d 100644 --- a/internal/node.go +++ b/internal/node.go @@ -508,6 +508,8 @@ func extractSynchronizedBeforeSuiteAllProcsBody(arg interface{}) (func(SpecConte }, hasContext } +var errInterface = reflect.TypeOf((*error)(nil)).Elem() + func NewCleanupNode(deprecationTracker *types.DeprecationTracker, fail func(string, types.CodeLocation), args ...interface{}) (Node, []error) { decorations, remainingArgs := PartitionDecorations(args...) baseOffset := 2 @@ -530,7 +532,7 @@ func NewCleanupNode(deprecationTracker *types.DeprecationTracker, fail func(stri } callback := reflect.ValueOf(remainingArgs[0]) - if !(callback.Kind() == reflect.Func && callback.Type().NumOut() <= 1) { + if !(callback.Kind() == reflect.Func) { return Node{}, []error{types.GinkgoErrors.DeferCleanupInvalidFunction(cl)} } @@ -550,8 +552,12 @@ func NewCleanupNode(deprecationTracker *types.DeprecationTracker, fail func(stri } handleFailure := func(out []reflect.Value) { - if len(out) == 1 && !out[0].IsNil() { - fail(fmt.Sprintf("DeferCleanup callback returned error: %v", out[0]), cl) + if len(out) == 0 { + return + } + last := out[len(out)-1] + if last.Type().Implements(errInterface) && !last.IsNil() { + fail(fmt.Sprintf("DeferCleanup callback returned error: %v", last), cl) } } diff --git a/internal/node_test.go b/internal/node_test.go index cb7f33837..b7e0f2a7d 100644 --- a/internal/node_test.go +++ b/internal/node_test.go @@ -894,23 +894,28 @@ var _ = Describe("Node", func() { }) }) - Context("when passed a function that returns too many values", func() { - It("errors", func() { - node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() (int, error) { - return 0, nil + Context("when passed a function that does not return", func() { + It("creates a body that runs the function and never calls the fail handler", func() { + didRun := false + node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() { + didRun = true }) - Ω(node.IsZero()).Should(BeTrue()) - Ω(errs).Should(ConsistOf(types.GinkgoErrors.DeferCleanupInvalidFunction(cl))) + Ω(node.CodeLocation).Should(Equal(cl)) + Ω(node.NodeType).Should(Equal(types.NodeTypeCleanupInvalid)) + Ω(errs).Should(BeEmpty()) + + node.Body(internal.NewSpecContext(nil)) + Ω(didRun).Should(BeTrue()) Ω(capturedFailure).Should(BeZero()) Ω(capturedCL).Should(BeZero()) }) }) - - Context("when passed a function that does not return", func() { + Context("when passed a function that returns somethign that isn't an error", func() { It("creates a body that runs the function and never calls the fail handler", func() { didRun := false - node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() { + node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() (string, int) { didRun = true + return "not-an-error", 17 }) Ω(node.CodeLocation).Should(Equal(cl)) Ω(node.NodeType).Should(Equal(types.NodeTypeCleanupInvalid)) @@ -923,12 +928,12 @@ var _ = Describe("Node", func() { }) }) - Context("when passed a function that returns nil", func() { + Context("when passed a function that returns a nil error", func() { It("creates a body that runs the function and does not call the fail handler", func() { didRun := false - node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() error { + node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() (string, int, error) { didRun = true - return nil + return "not-an-error", 17, nil }) Ω(node.CodeLocation).Should(Equal(cl)) Ω(node.NodeType).Should(Equal(types.NodeTypeCleanupInvalid)) @@ -941,12 +946,12 @@ var _ = Describe("Node", func() { }) }) - Context("when passed a function that returns an error", func() { - It("creates a body that runs the function and does not call the fail handler", func() { + Context("when passed a function that returns an error for its final return value", func() { + It("creates a body that runs the function and calls the fail handler", func() { didRun := false - node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() error { + node, errs := internal.NewCleanupNode(dt, failFunc, cl, func() (string, int, error) { didRun = true - return fmt.Errorf("welp") + return "not-an-error", 17, fmt.Errorf("welp") }) Ω(node.CodeLocation).Should(Equal(cl)) Ω(node.NodeType).Should(Equal(types.NodeTypeCleanupInvalid))