From c4261b06d89cb1592ef5fed68fb3a6be32b422bc Mon Sep 17 00:00:00 2001 From: Jay Petacat Date: Mon, 6 Jan 2020 18:31:55 -0500 Subject: [PATCH] Support Go 1.13 error chains in `Cause` Imagine module A imports module B and both use `pkg/errors`. A uses `errors.Cause` to inspect wrapped errors returned from B. As-is, B cannot migrate from `errors.Wrap` to `fmt.Errorf("%w", err)` because that would break `errors.Cause` calls in A. With this change merged, `errors.Cause` becomes forwards-compatible with Go 1.13 error chains. Module B will be free to switch to `fmt.Errorf("%w", err)` and that will not break module A (so long as the top-level project pulls in the newer version of `pkg/errors`). --- cause.go | 29 +++++++++++++++++++++++++++++ errors.go | 26 -------------------------- go113.go | 33 +++++++++++++++++++++++++++++++++ go113_test.go | 24 +++++++++++++++++++++++- 4 files changed, 85 insertions(+), 27 deletions(-) create mode 100644 cause.go diff --git a/cause.go b/cause.go new file mode 100644 index 0000000..566f88b --- /dev/null +++ b/cause.go @@ -0,0 +1,29 @@ +// +build !go1.13 + +package errors + +// Cause recursively unwraps an error chain and returns the underlying cause of +// the error, if possible. An error value has a cause if it implements the +// following interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/errors.go b/errors.go index 161aea2..a9840ec 100644 --- a/errors.go +++ b/errors.go @@ -260,29 +260,3 @@ func (w *withMessage) Format(s fmt.State, verb rune) { io.WriteString(s, w.Error()) } } - -// Cause returns the underlying cause of the error, if possible. -// An error value has a cause if it implements the following -// interface: -// -// type causer interface { -// Cause() error -// } -// -// If the error does not implement Cause, the original error will -// be returned. If the error is nil, nil will be returned without further -// investigation. -func Cause(err error) error { - type causer interface { - Cause() error - } - - for err != nil { - cause, ok := err.(causer) - if !ok { - break - } - err = cause.Cause() - } - return err -} diff --git a/go113.go b/go113.go index be0d10d..ed0dc7a 100644 --- a/go113.go +++ b/go113.go @@ -36,3 +36,36 @@ func As(err error, target interface{}) bool { return stderrors.As(err, target) } func Unwrap(err error) error { return stderrors.Unwrap(err) } + +// Cause recursively unwraps an error chain and returns the underlying cause of +// the error, if possible. There are two ways that an error value may provide a +// cause. First, the error may implement the following interface: +// +// type causer interface { +// Cause() error +// } +// +// Second, the error may return a non-nil value when passed as an argument to +// the Unwrap function. This makes Cause forwards-compatible with Go 1.13 error +// chains. +// +// If an error value satisfies both methods of unwrapping, Cause will use the +// causer interface. +// +// If the error is nil, nil will be returned without further investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + if cause, ok := err.(causer); ok { + err = cause.Cause() + } else if unwrapped := Unwrap(err); unwrapped != nil { + err = unwrapped + } else { + break + } + } + return err +} diff --git a/go113_test.go b/go113_test.go index 4ea37e6..7da3788 100644 --- a/go113_test.go +++ b/go113_test.go @@ -9,7 +9,29 @@ import ( "testing" ) -func TestErrorChainCompat(t *testing.T) { +func TestCauseErrorChainCompat(t *testing.T) { + err := stderrors.New("the cause!") + + // Wrap error using the standard library + wrapped := fmt.Errorf("wrapped with stdlib: %w", err) + if Cause(wrapped) != err { + t.Errorf("Cause does not support Go 1.13 error chains") + } + + // Wrap in another layer using pkg/errors + wrapped = WithMessage(wrapped, "wrapped with pkg/errors") + if Cause(wrapped) != err { + t.Errorf("Cause does not support Go 1.13 error chains") + } + + // Wrap in another layer using the standard library + wrapped = fmt.Errorf("wrapped with stdlib: %w", wrapped) + if Cause(wrapped) != err { + t.Errorf("Cause does not support Go 1.13 error chains") + } +} + +func TestWrapErrorChainCompat(t *testing.T) { err := stderrors.New("error that gets wrapped") wrapped := Wrap(err, "wrapped up") if !stderrors.Is(wrapped, err) {