Skip to content

Commit

Permalink
coap-gateway: check if oauth2 errors are temporary
Browse files Browse the repository at this point in the history
jwt package doesn't propagate correctly error:
golang/oauth2#635
  • Loading branch information
jkralik committed Mar 14, 2023
1 parent eb7fdbb commit ef3df4c
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 26 deletions.
54 changes: 37 additions & 17 deletions coap-gateway/service/message/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"bytes"
"context"
"errors"
"strings"
"net/http"

"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
"github.com/plgd-dev/go-coap/v3/message/status"
"golang.org/x/oauth2"
grpcCodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
grpcStatus "google.golang.org/grpc/status"
)

func GetResponse(ctx context.Context, messagePool *pool.Pool, code codes.Code, token message.Token, contentFormat message.MediaType, payload []byte) (*pool.Message, func()) {
Expand All @@ -26,7 +28,11 @@ func GetResponse(ctx context.Context, messagePool *pool.Pool, code codes.Code, t
}
}

// IsTempError returns true if error is temporary. Only certain errors are not considered as temporary errors.
func IsTempError(err error) bool {
if err == nil {
return false
}
var isTemporary interface {
Temporary() bool
}
Expand All @@ -39,29 +45,43 @@ func IsTempError(err error) bool {
if errors.As(err, &isTimeout) && isTimeout.Timeout() {
return true
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
var grpcStatus interface {
GRPCStatus() *status.Status
GRPCStatus() *grpcStatus.Status
}
if errors.As(err, &grpcStatus) {
switch grpcStatus.GRPCStatus().Code() {
case grpcCodes.Unavailable, grpcCodes.DeadlineExceeded, grpcCodes.Canceled, grpcCodes.Aborted:
return true
case grpcCodes.PermissionDenied,
grpcCodes.Unauthenticated,
grpcCodes.NotFound,
grpcCodes.AlreadyExists,
grpcCodes.InvalidArgument:
return false
}
return true
}
switch {
// TODO: We could optimize this by using error.Is to avoid string comparison.
case strings.Contains(err.Error(), "connect: connection refused"),
strings.Contains(err.Error(), "i/o timeout"),
strings.Contains(err.Error(), "TLS handshake timeout"),
strings.Contains(err.Error(), `http2:`), // any error at http2 protocol is considered as temporary error
strings.Contains(err.Error(), `write: broken pipe`),
strings.Contains(err.Error(), `request canceled while waiting for connection`),
strings.Contains(err.Error(), `authentication handshake failed`),
strings.Contains(err.Error(), context.DeadlineExceeded.Error()),
strings.Contains(err.Error(), context.Canceled.Error()):
oauth2Err := &oauth2.RetrieveError{}
if errors.As(err, &oauth2Err) {
if oauth2Err.Response != nil {
switch oauth2Err.Response.StatusCode {
case
http.StatusBadRequest,
http.StatusConflict,
http.StatusNotFound,
http.StatusForbidden,
http.StatusUnauthorized:
return false
}
}
return true
}
return false
if _, ok := status.FromError(err); ok {
// coap status code is not temporary
return false
}
return true
}

func GetErrorResponse(ctx context.Context, messagePool *pool.Pool, code codes.Code, token message.Token, err error) (*pool.Message, func()) {
Expand Down
110 changes: 110 additions & 0 deletions coap-gateway/service/message/response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package message

import (
"context"
"fmt"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type testError struct {
isTemporary bool
isTimeout bool
grpcStatus status.Status
oauthErr *oauth2.RetrieveError
}

type oauth2RetrieveError struct {
oauthErr *oauth2.RetrieveError
}

func (e *oauth2RetrieveError) Unwrap() error {
return e.oauthErr
}

func (e *oauth2RetrieveError) Error() string {
return "test oauth error"
}

func (e *testError) Error() string {
return "test error"
}

func (e *testError) Temporary() bool {
return e.isTemporary
}

func (e *testError) Timeout() bool {
return e.isTimeout
}

func (e *testError) GRPCStatus() *status.Status {
return &e.grpcStatus
}

func (e *testError) Unwrap() error {
return e.oauthErr
}

func TestIsTempError(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil",
args: args{},
want: false,
},
{
name: "not temporary",
args: args{err: fmt.Errorf("err: %w", &testError{grpcStatus: *status.Convert(status.Error(codes.Unauthenticated, "unauthenticated"))})},
want: false,
},
{
name: "any error as temporary",
args: args{err: fmt.Errorf("any error")},
want: true,
},
{
name: "temporary",
args: args{err: fmt.Errorf("err: %w", &testError{isTemporary: true})},
want: true,
},
{
name: "timeout",
args: args{err: fmt.Errorf("err: %w", &testError{isTimeout: true})},
want: true,
},
{
name: "grpcTemporary",
args: args{err: fmt.Errorf("err: %w", &testError{grpcStatus: *status.FromContextError(context.DeadlineExceeded)})},
want: true,
},
{
name: "oauth2Temporary",
args: args{err: fmt.Errorf("err: %w", &oauth2RetrieveError{oauthErr: &oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}})},
want: true,
},
{
name: "oauth2NotTemporary",
args: args{err: fmt.Errorf("err: %w", &oauth2RetrieveError{oauthErr: &oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusUnauthorized}}})},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsTempError(tt.args.err)
assert.Equal(t, tt.want, got)
})
}
}
8 changes: 4 additions & 4 deletions coap-gateway/service/signIn.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,12 @@ func signInPostHandler(req *mux.Message, client *session, signIn CoapSignInReq)
expiresIn := validUntilToExpiresIn(validUntil)
accept, out, err := getSignInContent(expiresIn, req.Options())
if err != nil {
return nil, statusErrorf(coapCodes.InternalServerError, errFmtSignIn, err)
return nil, statusErrorf(coapCodes.ServiceUnavailable, errFmtSignIn, err)
}

updateDeviceMetadataResp, err := client.updateBySignInData(ctx, upd, deviceID, signIn.UserID)
if err != nil {
return nil, statusErrorf(coapCodes.InternalServerError, errFmtSignIn, err)
return nil, statusErrorf(coapCodes.ServiceUnavailable, errFmtSignIn, err)
}

setExpirationClientCache(client.server.expirationClientCache, deviceID, client, validUntil)
Expand All @@ -329,7 +329,7 @@ func signInPostHandler(req *mux.Message, client *session, signIn CoapSignInReq)
// try to register observations to the device at the cloud.
setNewDeviceObserver(x.ctx, x.client, x.deviceID, x.resetObservationType, x.twinEnabled)
}); err != nil {
return nil, statusErrorf(coapCodes.InternalServerError, errFmtSignIn, fmt.Errorf("failed to register device observer: %w", err))
return nil, statusErrorf(coapCodes.ServiceUnavailable, errFmtSignIn, fmt.Errorf("failed to register device observer: %w", err))
}

return client.createResponse(coapCodes.Changed, req.Token(), accept, out), nil
Expand Down Expand Up @@ -367,7 +367,7 @@ func signOutPostHandler(req *mux.Message, client *session, signOut CoapSignInReq
if signOut.DeviceID == "" || signOut.UserID == "" || signOut.AccessToken == "" {
authCurrentCtx, err := client.GetAuthorizationContext()
if err != nil {
return nil, statusErrorf(coapCodes.InternalServerError, errFmtSignOut, err)
return nil, statusErrorf(coapCodes.BadRequest, errFmtSignOut, err)
}
signOut = signOut.updateOAUthRequestIfEmpty(authCurrentCtx.DeviceID, authCurrentCtx.UserID, authCurrentCtx.AccessToken)
}
Expand Down
2 changes: 1 addition & 1 deletion coap-gateway/service/signOff.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func signOffHandler(req *mux.Message, client *session) (*pool.Message, error) {
return nil, statusErrorf(coapconv.GrpcErr2CoapCode(err, coapconv.Delete), errFmtSignOff, err)
}
if len(respIS.GetDeviceIds()) != 1 {
return nil, statusErrorf(coapCodes.InternalServerError, errFmtSignOff, fmt.Errorf("cannot remove device %v from user", deviceID))
return nil, statusErrorf(coapCodes.BadRequest, errFmtSignOff, fmt.Errorf("cannot remove device %v from user", deviceID))
}

client.CleanUp(true)
Expand Down
2 changes: 1 addition & 1 deletion coap-gateway/service/signUp.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func signUpPostHandler(req *mux.Message, client *session) (*pool.Message, error)

accept, out, err := getSignUpContent(token, owner, validUntil, req.Options())
if err != nil {
return nil, statusErrorf(coapCodes.InternalServerError, errFmtSignUP, err)
return nil, statusErrorf(coapCodes.BadRequest, errFmtSignUP, err)
}

return client.createResponse(coapCodes.Changed, req.Token(), accept, out), nil
Expand Down
4 changes: 3 additions & 1 deletion http-gateway/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ WORKDIR $GOPATH/src/github.com/plgd-dev/hub
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN go mod vendor
RUN ( cd /usr/local/go && patch -p1 < $GOPATH/src/github.com/plgd-dev/hub/tools/docker/patches/shrink_tls_conn.patch )
RUN ( cd ./vendor/golang.org/x/oauth2 && patch -p1 < $GOPATH/src/github.com/plgd-dev/hub/tools/docker/patches/golang_org_x_oauth2_propagate_error.patch )
WORKDIR $GOPATH/src/github.com/plgd-dev/hub/http-gateway
RUN CGO_ENABLED=0 go build -o /go/bin/http-gateway ./cmd/service
RUN CGO_ENABLED=0 go build -mod=vendor -o /go/bin/http-gateway ./cmd/service

FROM node:16 AS build-web
COPY http-gateway/web /web
Expand Down
4 changes: 3 additions & 1 deletion tools/docker/Dockerfile.go1.18
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ WORKDIR $GOPATH/src/github.com/plgd-dev/hub
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN go mod vendor
RUN ( cd /usr/local/go && patch -p1 < $GOPATH/src/github.com/plgd-dev/hub/tools/docker/patches/shrink_tls_conn.patch )
RUN ( cd ./vendor/golang.org/x/oauth2 && patch -p1 < $GOPATH/src/github.com/plgd-dev/hub/tools/docker/patches/golang_org_x_oauth2_propagate_error.patch )
WORKDIR $GOPATH/src/github.com/plgd-dev/hub/$DIRECTORY
RUN go build -o /go/bin/$NAME ./cmd/service
RUN go build -mod=vendor -o /go/bin/$NAME ./cmd/service

FROM alpine:3.17 AS service
ARG NAME
Expand Down
4 changes: 3 additions & 1 deletion tools/docker/Dockerfile.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ WORKDIR $GOPATH/src/github.com/plgd-dev/hub
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN go mod vendor
RUN ( cd /usr/local/go && patch -p1 < $GOPATH/src/github.com/plgd-dev/hub/tools/docker/patches/shrink_tls_conn.patch )
RUN ( cd ./vendor/golang.org/x/oauth2 && patch -p1 < $GOPATH/src/github.com/plgd-dev/hub/tools/docker/patches/golang_org_x_oauth2_propagate_error.patch )
WORKDIR $GOPATH/src/github.com/plgd-dev/hub/@DIRECTORY@
RUN CGO_ENABLED=0 go build -o /go/bin/@NAME@ ./cmd/service
RUN CGO_ENABLED=0 go build -mod=vendor -o /go/bin/@NAME@ ./cmd/service

FROM alpine:3.17 AS security-provider
RUN apk add -U --no-cache ca-certificates
Expand Down
26 changes: 26 additions & 0 deletions tools/docker/patches/golang_org_x_oauth2_propagate_error.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
diff --git a/internal/oauth2.go b/internal/oauth2.go
index c0ab196..cad4b1f 100644
--- a/internal/oauth2.go
+++ b/internal/oauth2.go
@@ -26,7 +26,7 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
if err != nil {
- return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err)
+ return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %w", err)
}
}
parsed, ok := parsedKey.(*rsa.PrivateKey)
diff --git a/internal/token.go b/internal/token.go
index b4723fc..7b96171 100644
--- a/internal/token.go
+++ b/internal/token.go
@@ -234,7 +234,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
r.Body.Close()
if err != nil {
- return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
+ return nil, fmt.Errorf("oauth2: cannot fetch token: %w", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{

0 comments on commit ef3df4c

Please sign in to comment.