Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding canonical Keyfunc functions for RSA, ECDSA, EdDSA and HMAC #275

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,10 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte
return nil, err
}
}

// ECDSAPublicKey represents a [Keyfunc] that returns the ECDSA key specified in
// key. Furthermore, it checks, whether the signing method matches
// [SigningMethodECDSA].
func ECDSAPublicKey(key *ecdsa.PublicKey) Keyfunc {
return secureKeyFunc(key, []string{"ES256", "ES384", "ES512"})
}
7 changes: 7 additions & 0 deletions ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,10 @@ func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]by

return sig, nil
}

// Ed25519PublicKey represents a [Keyfunc] that returns the Ed25519 key
// specified in key. Furthermore, it checks, whether the signing method matches
// [SigningMethodEdDSA].
func Ed25519PublicKey(key ed25519.PublicKey) Keyfunc {
return secureKeyFunc(key, []string{"EdDSA"})
}
23 changes: 11 additions & 12 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
})
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, jwt.PresharedKey([]byte("AllYourBase")))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
Expand All @@ -103,9 +101,11 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))
token, err := jwt.ParseWithClaims(
tokenString, &MyCustomClaims{},
jwt.PresharedKey([]byte("AllYourBase")),
jwt.WithLeeway(5*time.Second),
)

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
Expand Down Expand Up @@ -138,9 +138,10 @@ func (m MyCustomClaims) Validate() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))
token, err := jwt.ParseWithClaims(
tokenString, &MyCustomClaims{},
jwt.PresharedKey([]byte("AllYourBase")),
jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
Expand All @@ -156,9 +157,7 @@ func ExampleParse_errorChecking() {
// Token from another example. This token is expired
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
})
token, err := jwt.Parse(tokenString, jwt.PresharedKey([]byte("AllYourBase")))

if token.Valid {
fmt.Println("You look nice today")
Expand Down
5 changes: 5 additions & 0 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,8 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte,

return nil, ErrInvalidKeyType
}

// PresharedKey represents a [Keyfunc] that simply returns the key specified in the byte slice.
func PresharedKey(key []byte) Keyfunc {
return secureKeyFunc(key, []string{"HS256", "HS384", "HS512"})
}
14 changes: 2 additions & 12 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)

Expand Down Expand Up @@ -60,17 +59,8 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf

// Verify signing method is in the required set
if p.validMethods != nil {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range p.validMethods {
if m == alg {
signingMethodValid = true
break
}
}
if !signingMethodValid {
// signing method is not in the listed set
return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
if err = token.hasValidSigningMethod(p.validMethods); err != nil {
return token, err
}
}

Expand Down
7 changes: 7 additions & 0 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,10 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) ([]byte,
return nil, err
}
}

// RSAPublicKey represents a [Keyfunc] that returns the RSA key specified in
// key. Furthermore, it checks, whether the signing method matches
// [SigningMethodRSA].
func RSAPublicKey(key *rsa.PublicKey) Keyfunc {
return secureKeyFunc(key, []string{"RS256", "RS384", "RS512"})
}
3 changes: 2 additions & 1 deletion test/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package test

import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"os"

Expand Down Expand Up @@ -56,7 +57,7 @@ func LoadECPrivateKeyFromDisk(location string) crypto.PrivateKey {
return key
}

func LoadECPublicKeyFromDisk(location string) crypto.PublicKey {
func LoadECPublicKeyFromDisk(location string) *ecdsa.PublicKey {
keyData, e := os.ReadFile(location)
if e != nil {
panic(e.Error())
Expand Down
34 changes: 34 additions & 0 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jwt
import (
"encoding/base64"
"encoding/json"
"fmt"
)

// Keyfunc will be used by the Parse methods as a callback function to supply
Expand Down Expand Up @@ -81,3 +82,36 @@ func (t *Token) SigningString() (string, error) {
func (*Token) EncodeSegment(seg []byte) string {
return base64.RawURLEncoding.EncodeToString(seg)
}

// hasValidSigningMethod is a utility function that checks, if the signing
// method of the token is included in the validMethods slice.
func (token *Token) hasValidSigningMethod(validMethods []string) error {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range validMethods {
if m == alg {
signingMethodValid = true
break
}
}

if !signingMethodValid {
// signing method is not in the listed set
return newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
}

return nil
}

// secureKeyFunc returns a secure [Keyfunc] for the specified key that also
// includes a signing method check.
func secureKeyFunc(key any, validMethods []string) Keyfunc {
return func(t *Token) (interface{}, error) {
// Check, if the signing method matches
if err := t.hasValidSigningMethod(validMethods); err != nil {
return nil, err
}

return key, nil
}
}
71 changes: 58 additions & 13 deletions token_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package jwt_test
package jwt

import (
"errors"
"reflect"
"testing"

"github.com/golang-jwt/jwt/v5"
)

func TestToken_SigningString(t1 *testing.T) {
type fields struct {
Raw string
Method jwt.SigningMethod
Method SigningMethod
Header map[string]interface{}
Claims jwt.Claims
Claims Claims
Signature []byte
Valid bool
}
Expand All @@ -25,12 +25,12 @@ func TestToken_SigningString(t1 *testing.T) {
name: "",
fields: fields{
Raw: "",
Method: jwt.SigningMethodHS256,
Method: SigningMethodHS256,
Header: map[string]interface{}{
"typ": "JWT",
"alg": jwt.SigningMethodHS256.Alg(),
"alg": SigningMethodHS256.Alg(),
},
Claims: jwt.RegisteredClaims{},
Claims: RegisteredClaims{},
Valid: false,
},
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
Expand All @@ -39,7 +39,7 @@ func TestToken_SigningString(t1 *testing.T) {
}
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
t := &jwt.Token{
t := &Token{
Raw: tt.fields.Raw,
Method: tt.fields.Method,
Header: tt.fields.Header,
Expand All @@ -60,13 +60,13 @@ func TestToken_SigningString(t1 *testing.T) {
}

func BenchmarkToken_SigningString(b *testing.B) {
t := &jwt.Token{
Method: jwt.SigningMethodHS256,
t := &Token{
Method: SigningMethodHS256,
Header: map[string]interface{}{
"typ": "JWT",
"alg": jwt.SigningMethodHS256.Alg(),
"alg": SigningMethodHS256.Alg(),
},
Claims: jwt.RegisteredClaims{},
Claims: RegisteredClaims{},
}
b.Run("BenchmarkToken_SigningString", func(b *testing.B) {
b.ResetTimer()
Expand All @@ -76,3 +76,48 @@ func BenchmarkToken_SigningString(b *testing.B) {
}
})
}

func Test_secureKeyFunc(t *testing.T) {
type fields struct {
token *Token
}
type args struct {
key any
validMethods []string
}
tests := []struct {
name string
fields fields
args args
wantKey any
wantErr error
}{
{
name: "invalid method",
fields: fields{&Token{Header: map[string]interface{}{"alg": "RS512"}, Method: SigningMethodRS512}},
args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}},
wantKey: nil,
wantErr: ErrTokenSignatureInvalid,
},
{
name: "correct method",
fields: fields{&Token{Header: map[string]interface{}{"alg": "HS256"}, Method: SigningMethodHS256}},
args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}},
wantKey: []byte("mysecret"),
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyfunc := secureKeyFunc(tt.args.key, tt.args.validMethods)
gotKey, gotErr := keyfunc(tt.fields.token)

if !reflect.DeepEqual(gotKey, tt.wantKey) {
t.Errorf("secureKeyFunc() key = %v, want %v", gotKey, tt.wantKey)
}
if (gotErr != nil) && !errors.Is(gotErr, tt.wantErr) {
t.Errorf("secureKeyFunc() err = %v, want %v", gotErr, tt.wantErr)
}
})
}
}