Skip to content

Commit

Permalink
feat: custom json and base64 encoders for Token and Parser
Browse files Browse the repository at this point in the history
  • Loading branch information
dcalsky committed Apr 2, 2023
1 parent b88a60f commit 7c726ab
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 21 deletions.
13 changes: 13 additions & 0 deletions encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package jwt

// Base64Encoder is an interface that allows to implement custom Base64 encoding/decoding algorithms.
type Base64Encoder interface {
EncodeToString(src []byte) string
DecodeString(s string) ([]byte, error)
}

// JSONEncoder is an interface that allows to implement custom JSON encoding/decoding algorithms.
type JSONEncoder interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}
26 changes: 26 additions & 0 deletions encoder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package jwt_test

import (
"encoding/base64"
"encoding/json"
)

type customJSONEncoder struct{}

func (s *customJSONEncoder) Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}

func (s *customJSONEncoder) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}

type customBase64Encoder struct{}

func (s *customBase64Encoder) EncodeToString(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}

func (s *customBase64Encoder) DecodeString(data string) ([]byte, error) {
return base64.RawURLEncoding.DecodeString(data)
}
46 changes: 36 additions & 10 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,25 @@ type Parser struct {
// If populated, only these methods will be considered valid.
validMethods []string

// Use JSON Number format in JSON decoder.
// Use JSON Number format in JSON decoder. This field is disabled when using a custom json encoder.
useJSONNumber bool

// Skip claims validation during token parsing.
skipClaimsValidation bool

validator *validator

// This field is disabled when using a custom base64 encoder.
decodeStrict bool

// This field is disabled when using a custom base64 encoder.
decodePaddingAllowed bool

// Custom base64 encoder.
base64Encoder Base64Encoder

// Custom json encoder.
jsonEncoder JSONEncoder
}

// NewParser creates a new Parser with the specified options
Expand Down Expand Up @@ -135,7 +143,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
}
return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
if p.jsonEncoder != nil {
err = p.jsonEncoder.Unmarshal(headerBytes, &token.Header)
} else {
err = json.Unmarshal(headerBytes, &token.Header)
}
if err != nil {
return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
}

Expand All @@ -146,21 +159,30 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
if claimBytes, err = p.DecodeSegment(parts[1]); err != nil {
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
}
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.useJSONNumber {
dec.UseNumber()
}

// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
mapClaims, isMapClaims := token.Claims.(MapClaims)
if p.jsonEncoder != nil {
if isMapClaims {
err = p.jsonEncoder.Unmarshal(claimBytes, &mapClaims)
} else {
err = p.jsonEncoder.Unmarshal(claimBytes, &claims)
}
} else {
err = dec.Decode(&claims)
decoder := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.useJSONNumber {
decoder.UseNumber()
}
if isMapClaims {
err = decoder.Decode(&mapClaims)
} else {
err = decoder.Decode(&claims)
}
}
// Handle decode error
if err != nil {
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
}

// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
Expand All @@ -177,6 +199,10 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
// take into account whether the [Parser] is configured with additional options,
// such as [WithStrictDecoding] or [WithPaddingAllowed].
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
if p.base64Encoder != nil {
return p.base64Encoder.DecodeString(seg)
}

encoding := base64.RawURLEncoding

if p.decodePaddingAllowed {
Expand Down
13 changes: 13 additions & 0 deletions parser_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,16 @@ func WithStrictDecoding() ParserOption {
p.decodeStrict = true
}
}

// WithJSONEncoder supports
func WithJSONEncoder(enc JSONEncoder) ParserOption {
return func(p *Parser) {
p.jsonEncoder = enc
}
}

func WithBase64Encoder(enc Base64Encoder) ParserOption {
return func(p *Parser) {
p.base64Encoder = enc
}
}

0 comments on commit 7c726ab

Please sign in to comment.