Skip to content

Commit

Permalink
feat: add backward-compat alias types
Browse files Browse the repository at this point in the history
  • Loading branch information
costela committed Feb 22, 2023
1 parent b3d339a commit 20b02ba
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 39 deletions.
2 changes: 1 addition & 1 deletion cmd/jwt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func verifyToken() error {
}

// Parse the token. Load the key from command line option
token, err := jwt.Parse(string(tokData), func(t *jwt.Token[jwt.MapClaims]) (interface{}, error) {
token, err := jwt.Parse(string(tokData), func(t *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
if isNone() {
return jwt.UnsafeAllowNoneSignatureType, nil
}
Expand Down
8 changes: 4 additions & 4 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand All @@ -103,7 +103,7 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

Expand Down Expand Up @@ -136,7 +136,7 @@ func (m MyCustomClaims) CustomValidation() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

Expand All @@ -154,7 +154,7 @@ func ExampleParse_errorChecking() {
// Token from another example. This token is expired
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand Down
2 changes: 1 addition & 1 deletion hmac_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func ExampleParse_hmac() {
// useful if you use multiple keys for your application. The standard is to use 'kid' in the
// head of the token to identify which key to use, but the parsed token (head and claims) is provided
// to the callback, providing flexibility.
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
Expand Down
4 changes: 2 additions & 2 deletions http_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func Example_getTokenViaHTTP() {
tokenString := strings.TrimSpace(buf.String())

// Parse the token
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
Expand Down Expand Up @@ -191,7 +191,7 @@ func authHandler(w http.ResponseWriter, r *http.Request) {
// only accessible with a valid token
func restrictedHandler(w http.ResponseWriter, r *http.Request) {
// Get token from request
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
token, err := request.ParseFromRequestWithClaims(r, request.OAuth2Extractor, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
Expand Down
6 changes: 3 additions & 3 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewParserFor[T Claims](options ...ParserOption) *Parser[T] {
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
func (p *Parser[T]) Parse(tokenString string, keyFunc KeyfuncFor[T]) (*TokenFor[T], error) {
token, parts, err := p.ParseUnverified(tokenString)
if err != nil {
return token, err
Expand Down Expand Up @@ -134,13 +134,13 @@ func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], er
//
// It's only ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from it.
func (p *Parser[T]) ParseUnverified(tokenString string) (token *Token[T], parts []string, err error) {
func (p *Parser[T]) ParseUnverified(tokenString string) (token *TokenFor[T], parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}

token = &Token[T]{Raw: tokenString}
token = &TokenFor[T]{Raw: tokenString}

// parse Header
var headerBytes []byte
Expand Down
20 changes: 10 additions & 10 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ const (
keyFuncNil
)

func getKeyFunc[T jwt.Claims](kind keyFuncKind) jwt.Keyfunc[T] {
func getKeyFunc[T jwt.Claims](kind keyFuncKind) jwt.KeyfuncFor[T] {
switch kind {
case keyFuncDefault:
return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestDefaultKey, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return jwtTestDefaultKey, nil }
case keyFuncECDSA:
return func(t *jwt.Token[T]) (interface{}, error) { return jwtTestEC256PublicKey, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return jwtTestEC256PublicKey, nil }
case keyFuncPadded:
return func(t *jwt.Token[T]) (interface{}, error) { return paddedKey, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return paddedKey, nil }
case keyFuncEmpty:
return func(t *jwt.Token[T]) (interface{}, error) { return nil, nil }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return nil, nil }
case keyFuncError:
return func(t *jwt.Token[T]) (interface{}, error) { return nil, errKeyFuncError }
return func(t *jwt.TokenFor[T]) (interface{}, error) { return nil, errKeyFuncError }
case keyFuncNil:
return nil
default:
Expand Down Expand Up @@ -371,8 +371,8 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string {

// cloneToken is necesssary to "forget" the type information back to a generic jwt.Claims.
// Assignment of parameterized types is currently (1.20) not supported.
func cloneToken[T jwt.Claims](tin *jwt.Token[T]) *jwt.Token[jwt.Claims] {
tout := &jwt.Token[jwt.Claims]{}
func cloneToken[T jwt.Claims](tin *jwt.TokenFor[T]) *jwt.TokenFor[jwt.Claims] {
tout := &jwt.TokenFor[jwt.Claims]{}
tout.Claims = tin.Claims
tout.Header = tin.Header
tout.Method = tin.Method
Expand All @@ -392,7 +392,7 @@ func TestParser_Parse(t *testing.T) {
}

// Parse the token
var token *jwt.Token[jwt.Claims]
var token *jwt.TokenFor[jwt.Claims]
var err error
var ve *jwt.ValidationError
switch data.claims.(type) {
Expand Down Expand Up @@ -485,7 +485,7 @@ func TestParser_ParseUnverified(t *testing.T) {
}

// Parse the token
var token *jwt.Token[jwt.Claims]
var token *jwt.TokenFor[jwt.Claims]
var err error
switch data.claims.(type) {
case jwt.MapClaims:
Expand Down
12 changes: 9 additions & 3 deletions request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ import (
// the logic for extracting a token. Several useful implementations are provided.
//
// You can provide options to modify parsing behavior
func ParseFromRequest[T jwt.Claims](req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc[T], options ...ParseFromRequestOption[T]) (token *jwt.Token[T], err error) {
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...Option) (token *jwt.Token, err error) {
return ParseFromRequestWithClaims(req, extractor, keyFunc, options...)
}

func ParseFromRequestWithClaims[T jwt.Claims](req *http.Request, extractor Extractor, keyFunc jwt.KeyfuncFor[T], options ...OptionFor[T]) (token *jwt.TokenFor[T], err error) {
// Create basic parser struct
p := &fromRequestParser[T]{
req: req,
Expand Down Expand Up @@ -45,10 +49,12 @@ type fromRequestParser[T jwt.Claims] struct {
parser *jwt.Parser[T]
}

type ParseFromRequestOption[T jwt.Claims] func(*fromRequestParser[T])
type OptionFor[T jwt.Claims] func(*fromRequestParser[T])

type Option = OptionFor[jwt.MapClaims]

// WithParser parses using a custom parser
func WithParser[T jwt.Claims](parser *jwt.Parser[T]) ParseFromRequestOption[T] {
func WithParser[T jwt.Claims](parser *jwt.Parser[T]) OptionFor[T] {
return func(p *fromRequestParser[T]) {
p.parser = parser
}
Expand Down
2 changes: 1 addition & 1 deletion request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestParseRequest(t *testing.T) {
// load keys from disk
privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key")
publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub")
keyfunc := func(*jwt.Token[jwt.MapClaims]) (interface{}, error) {
keyfunc := func(*jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
return publicKey, nil
}

Expand Down
30 changes: 18 additions & 12 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@ var DecodePaddingAllowed bool
// To use strict decoding, set this boolean to `true` prior to using this package.
var DecodeStrict bool

// Keyfunc will be used by the Parse methods as a callback function to supply
// KeyfuncFor[T] will be used by the Parse methods as a callback function to supply
// the key for verification. The function receives the parsed,
// but unverified Token. This allows you to use properties in the
// but unverified TokenFor[T]. This allows you to use properties in the
// Header of the token (such as `kid`) to identify which key to use.
type Keyfunc[T Claims] func(*Token[T]) (interface{}, error)
type KeyfuncFor[T Claims] func(*TokenFor[T]) (interface{}, error)

// Token represents a JWT Token. Different fields will be used depending on whether you're
// Keyfunc is an alias for KeyfuncFor[Claims], for backward compatibility.
type Keyfunc = KeyfuncFor[MapClaims]

// TokenFor represents a JWT TokenFor. Different fields will be used depending on whether you're
// creating or parsing/verifying a token.
type Token[T Claims] struct {
type TokenFor[T Claims] struct {
Raw string // The raw token. Populated when you Parse a token
Method SigningMethod // The signing method used or to be used
Header map[string]interface{} // The first segment of the token
Expand All @@ -36,14 +39,17 @@ type Token[T Claims] struct {
Valid bool // Is the token valid? Populated when you Parse/Verify a token
}

// Token is an alias for TokenFor[Claims], for backward compatibility.
type Token = TokenFor[MapClaims]

// New creates a new Token with the specified signing method and an empty map of claims.
func New(method SigningMethod) *Token[MapClaims] {
func New(method SigningMethod) *Token {
return NewWithClaims(method, MapClaims{})
}

// NewWithClaims creates a new Token with the specified signing method and claims.
func NewWithClaims[T Claims](method SigningMethod, claims T) *Token[T] {
return &Token[T]{
func NewWithClaims[T Claims](method SigningMethod, claims T) *TokenFor[T] {
return &TokenFor[T]{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
Expand All @@ -55,7 +61,7 @@ func NewWithClaims[T Claims](method SigningMethod, claims T) *Token[T] {

// SignedString creates and returns a complete, signed JWT.
// The token is signed using the SigningMethod specified in the token.
func (t *Token[T]) SignedString(key interface{}) (string, error) {
func (t *TokenFor[T]) SignedString(key interface{}) (string, error) {
var sig, sstr string
var err error
if sstr, err = t.SigningString(); err != nil {
Expand All @@ -71,7 +77,7 @@ func (t *Token[T]) SignedString(key interface{}) (string, error) {
// most expensive part of the whole deal. Unless you
// need this for something special, just go straight for
// the SignedString.
func (t *Token[T]) SigningString() (string, error) {
func (t *TokenFor[T]) SigningString() (string, error) {
var err error
var jsonValue []byte

Expand All @@ -95,7 +101,7 @@ func (t *Token[T]) SigningString() (string, error) {
// validate the 'alg' claim in the token matches the expected algorithm.
// For more details about the importance of validating the 'alg' claim,
// see https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
func Parse(tokenString string, keyFunc Keyfunc[MapClaims], options ...ParserOption) (*Token[MapClaims], error) {
func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
return NewParser(options...).Parse(tokenString, keyFunc)
}

Expand All @@ -104,7 +110,7 @@ func Parse(tokenString string, keyFunc Keyfunc[MapClaims], options ...ParserOpti
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
func ParseWithClaims[T Claims](tokenString string, keyFunc Keyfunc[T], options ...ParserOption) (*Token[T], error) {
func ParseWithClaims[T Claims](tokenString string, keyFunc KeyfuncFor[T], options ...ParserOption) (*TokenFor[T], error) {
return NewParserFor[T](options...).Parse(tokenString, keyFunc)
}

Expand Down
4 changes: 2 additions & 2 deletions token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestToken_SigningString(t1 *testing.T) {
}
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
t := &jwt.Token[jwt.Claims]{
t := &jwt.TokenFor[jwt.Claims]{
Raw: tt.fields.Raw,
Method: tt.fields.Method,
Header: tt.fields.Header,
Expand All @@ -61,7 +61,7 @@ func TestToken_SigningString(t1 *testing.T) {
}

func BenchmarkToken_SigningString(b *testing.B) {
t := &jwt.Token[jwt.Claims]{
t := &jwt.TokenFor[jwt.Claims]{
Method: jwt.SigningMethodHS256,
Header: map[string]interface{}{
"typ": "JWT",
Expand Down

0 comments on commit 20b02ba

Please sign in to comment.