diff --git a/README.md b/README.md index fbbb7b1a..72f20676 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,6 @@ import ( "log" "net/http" - "github.com/auth0/go-jwt-middleware/v2" "github.com/auth0/go-jwt-middleware/v2/validator" jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" ) diff --git a/examples/gin-example/main.go b/examples/gin-example/main.go index 03cc34e2..a9afffc2 100644 --- a/examples/gin-example/main.go +++ b/examples/gin-example/main.go @@ -39,9 +39,29 @@ import ( // "username": "user123", // "shouldReject": true // } +// +// You can also try out the /multiple endpoint. This endpoint accepts tokens signed by multiple issuers. Try the +// token below which has a different issuer: +// +// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1tdWx0aXBsZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtbXVsdGlwbGUtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.9zV_bY1wAmQlMCPlXOppx1Y9_z_T_wNng9-yfQk4I0c +// +// which is signed with 'secret' and has the data: +// +// { +// "iss": "go-jwt-middleware-multiple-example", +// "aud": "audience-multiple-example", +// "sub": "1234567890", +// "name": "John Doe", +// "iat": 1516239022, +// "username": "user123" +// } +// +// You can also try the previous tokens with the /multiple endpoint. The first token will be valid the second will fail because +// the custom validator rejects it (shouldReject: true) func main() { router := gin.Default() + router.GET("/", checkJWT(), func(ctx *gin.Context) { claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) if !ok { @@ -52,7 +72,37 @@ func main() { return } - customClaims, ok := claims.CustomClaims.(*CustomClaimsExample) + localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample) + if !ok { + ctx.AbortWithStatusJSON( + http.StatusInternalServerError, + map[string]string{"message": "Failed to cast custom JWT claims to specific type."}, + ) + return + } + + if len(localCustomClaims.Username) == 0 { + ctx.AbortWithStatusJSON( + http.StatusBadRequest, + map[string]string{"message": "Username in JWT claims was empty."}, + ) + return + } + + ctx.JSON(http.StatusOK, claims) + }) + + router.GET("/multiple", checkJWTMultiple(), func(ctx *gin.Context) { + claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) + if !ok { + ctx.AbortWithStatusJSON( + http.StatusInternalServerError, + map[string]string{"message": "Failed to get validated JWT claims."}, + ) + return + } + + localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample) if !ok { ctx.AbortWithStatusJSON( http.StatusInternalServerError, @@ -61,7 +111,7 @@ func main() { return } - if len(customClaims.Username) == 0 { + if len(localCustomClaims.Username) == 0 { ctx.AbortWithStatusJSON( http.StatusBadRequest, map[string]string{"message": "Username in JWT claims was empty."}, diff --git a/examples/gin-example/middleware.go b/examples/gin-example/middleware.go index 104cd07c..1752f7b2 100644 --- a/examples/gin-example/middleware.go +++ b/examples/gin-example/middleware.go @@ -2,6 +2,7 @@ package main import ( "context" + "gopkg.in/go-jose/go-jose.v2/jwt" "log" "net/http" "time" @@ -16,10 +17,12 @@ var ( signingKey = []byte("secret") // The issuer of our token. - issuer = "go-jwt-middleware-example" + issuer = "go-jwt-middleware-example" + issuerTwo = "go-jwt-middleware-multiple-example" // The audience of our token. - audience = []string{"audience-example"} + audience = []string{"audience-example"} + audienceTwo = []string{"audience-multiple-example"} // Our token must be signed using this data. keyFunc = func(ctx context.Context) (interface{}, error) { @@ -76,3 +79,50 @@ func checkJWT() gin.HandlerFunc { } } } + +func checkJWTMultiple() gin.HandlerFunc { + // Set up the validator. + jwtValidator, err := validator.NewValidator( + keyFunc, + validator.HS256, + validator.WithCustomClaims(customClaims), + validator.WithAllowedClockSkew(30*time.Second), + validator.WithExpectedClaims(jwt.Expected{ + Issuer: issuer, + Audience: audience, + }, jwt.Expected{ + Issuer: issuerTwo, + Audience: audienceTwo, + }), + ) + if err != nil { + log.Fatalf("failed to set up the validator: %v", err) + } + + errorHandler := func(w http.ResponseWriter, r *http.Request, err error) { + log.Printf("Encountered error while validating JWT: %v", err) + } + + middleware := jwtmiddleware.New( + jwtValidator.ValidateToken, + jwtmiddleware.WithErrorHandler(errorHandler), + ) + + return func(ctx *gin.Context) { + encounteredError := true + var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + encounteredError = false + ctx.Request = r + ctx.Next() + } + + middleware.CheckJWT(handler).ServeHTTP(ctx.Writer, ctx.Request) + + if encounteredError { + ctx.AbortWithStatusJSON( + http.StatusUnauthorized, + map[string]string{"message": "JWT is invalid."}, + ) + } + } +} diff --git a/examples/http-example/main.go b/examples/http-example/main.go index b7ad5eb9..d824b668 100644 --- a/examples/http-example/main.go +++ b/examples/http-example/main.go @@ -8,9 +8,8 @@ import ( "net/http" "time" - "github.com/auth0/go-jwt-middleware/v2" - "github.com/auth0/go-jwt-middleware/v2/validator" jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" + "github.com/auth0/go-jwt-middleware/v2/validator" ) var ( diff --git a/examples/http-jwks-example/main.go b/examples/http-jwks-example/main.go index 81776dcc..93ee1440 100644 --- a/examples/http-jwks-example/main.go +++ b/examples/http-jwks-example/main.go @@ -7,10 +7,9 @@ import ( "net/url" "time" - "github.com/auth0/go-jwt-middleware/v2" + jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" "github.com/auth0/go-jwt-middleware/v2/jwks" "github.com/auth0/go-jwt-middleware/v2/validator" - jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" ) var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/extractor.go b/extractor.go index 376e513c..33882665 100644 --- a/extractor.go +++ b/extractor.go @@ -23,7 +23,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) { authHeaderParts := strings.Fields(authHeader) if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") + return "", errors.New("authorization header format must be Bearer {token}") } return authHeaderParts[1], nil @@ -34,7 +34,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) { func CookieTokenExtractor(cookieName string) TokenExtractor { return func(r *http.Request) (string, error) { cookie, err := r.Cookie(cookieName) - if err == http.ErrNoCookie { + if errors.Is(err, http.ErrNoCookie) { return "", nil // No cookie, then no JWT, so no error. } diff --git a/extractor_test.go b/extractor_test.go index 3101847d..adca0443 100644 --- a/extractor_test.go +++ b/extractor_test.go @@ -38,7 +38,7 @@ func Test_AuthHeaderTokenExtractor(t *testing.T) { "Authorization": []string{"i-am-a-token"}, }, }, - wantError: "Authorization header format must be Bearer {token}", + wantError: "authorization header format must be Bearer {token}", }, } diff --git a/jwks/provider.go b/jwks/provider.go index 808cae75..aa30b9b8 100644 --- a/jwks/provider.go +++ b/jwks/provider.go @@ -21,9 +21,10 @@ import ( // getting and caching JWKS which can help reduce request time and potential // rate limiting from your provider. type Provider struct { - IssuerURL *url.URL // Required. - CustomJWKSURI *url.URL // Optional. - Client *http.Client + IssuerURL *url.URL // Required. + CustomJWKSURI *url.URL // Optional. + AdditionalProviders []Provider // Optional + Client *http.Client } // ProviderOption is how options for the Provider are set up. @@ -32,14 +33,24 @@ type ProviderOption func(*Provider) // NewProvider builds and returns a new *Provider. func NewProvider(issuerURL *url.URL, opts ...ProviderOption) *Provider { p := &Provider{ - IssuerURL: issuerURL, - Client: &http.Client{}, + Client: &http.Client{}, + AdditionalProviders: make([]Provider, 0), + } + + if issuerURL != nil { + p.IssuerURL = issuerURL } for _, opt := range opts { opt(p) } + for _, provider := range p.AdditionalProviders { + if provider.Client == nil { + provider.Client = p.Client + } + } + return p } @@ -56,6 +67,21 @@ func WithCustomJWKSURI(jwksURI *url.URL) ProviderOption { func WithCustomClient(c *http.Client) ProviderOption { return func(p *Provider) { p.Client = c + for _, provider := range p.AdditionalProviders { + provider.Client = c + } + } +} + +// WithAdditionalProviders allows validation with mutliple IssuerURLs if desired. If multiple issuers are specified, +// a jwt may be signed by any of them and be considered valid +func WithAdditionalProviders(issuerURL *url.URL, customJWKSURI *url.URL) ProviderOption { + return func(p *Provider) { + p.AdditionalProviders = append(p.AdditionalProviders, Provider{ + IssuerURL: issuerURL, + CustomJWKSURI: customJWKSURI, + Client: p.Client, + }) } } @@ -63,6 +89,25 @@ func WithCustomClient(c *http.Client) ProviderOption { // While it returns an interface to adhere to keyFunc, as long as the // error is nil the type will be *jose.JSONWebKeySet. func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) { + rawJwks, err := p.keyFunc(ctx) + + if len(p.AdditionalProviders) == 0 { + return rawJwks, err + } else { + var jwks *jose.JSONWebKeySet + jwks = rawJwks.(*jose.JSONWebKeySet) + for _, provider := range p.AdditionalProviders { + if rawJwks, err = provider.keyFunc(ctx); err != nil { + continue + } else { + jwks.Keys = append(jwks.Keys, rawJwks.(*jose.JSONWebKeySet).Keys...) + } + } + return jwks, err + } +} + +func (p *Provider) keyFunc(ctx context.Context) (interface{}, error) { jwksURI := p.CustomJWKSURI if jwksURI == nil { wkEndpoints, err := oidc.GetWellKnownEndpointsFromIssuerURL(ctx, p.Client, *p.IssuerURL) @@ -85,10 +130,12 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) { if err != nil { return nil, err } - defer response.Body.Close() + defer func() { + _ = response.Body.Close() + }() var jwks jose.JSONWebKeySet - if err := json.NewDecoder(response.Body).Decode(&jwks); err != nil { + if err = json.NewDecoder(response.Body).Decode(&jwks); err != nil { return nil, fmt.Errorf("could not decode jwks: %w", err) } diff --git a/validator/option.go b/validator/option.go index 12c1cc61..bd318299 100644 --- a/validator/option.go +++ b/validator/option.go @@ -1,6 +1,7 @@ package validator import ( + "gopkg.in/go-jose/go-jose.v2/jwt" "time" ) @@ -26,3 +27,16 @@ func WithCustomClaims(f func() CustomClaims) Option { v.customClaims = f } } + +// WithExpectedClaims allows fine-grained customization of the expected claims +func WithExpectedClaims(expectedClaims ...jwt.Expected) Option { + return func(v *Validator) { + if len(expectedClaims) == 0 { + return + } + if v.expectedClaims == nil { + v.expectedClaims = make([]jwt.Expected, 0) + } + v.expectedClaims = append(v.expectedClaims, expectedClaims...) + } +} diff --git a/validator/validator.go b/validator/validator.go index 2a302493..b28dc948 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "gopkg.in/go-jose/go-jose.v2/jwt" @@ -30,7 +31,7 @@ const ( type Validator struct { keyFunc func(context.Context) (interface{}, error) // Required. signatureAlgorithm SignatureAlgorithm // Required. - expectedClaims jwt.Expected // Internal. + expectedClaims []jwt.Expected // Internal. customClaims func() CustomClaims // Optional. allowedClockSkew time.Duration // Optional. } @@ -66,11 +67,61 @@ func New( if keyFunc == nil { return nil, errors.New("keyFunc is required but was nil") } - if issuerURL == "" { - return nil, errors.New("issuer url is required but was empty") + if _, ok := allowedSigningAlgorithms[signatureAlgorithm]; !ok { + return nil, errors.New("unsupported signature algorithm") + } + + v := &Validator{ + keyFunc: keyFunc, + signatureAlgorithm: signatureAlgorithm, + expectedClaims: make([]jwt.Expected, 0), } - if len(audience) == 0 { + + for _, opt := range opts { + opt(v) + } + + if len(v.expectedClaims) == 0 && issuerURL == "" { + return nil, errors.New("issuer url is required but was empty") + } else if len(v.expectedClaims) == 0 && len(audience) == 0 { return nil, errors.New("audience is required but was empty") + } else if len(issuerURL) > 0 && len(audience) > 0 { + v.expectedClaims = append(v.expectedClaims, jwt.Expected{ + Issuer: issuerURL, + Audience: audience, + }) + } + + if len(v.expectedClaims) == 0 { + return nil, errors.New("expected claims but none provided") + } + + for i, expected := range v.expectedClaims { + if expected.Issuer == "" { + return nil, fmt.Errorf("issuer url %d is required but was empty", i) + } + if len(expected.Audience) == 0 { + return nil, fmt.Errorf("audience %d is required but was empty", i) + } + } + + return v, nil +} + +// NewValidator sets up a new Validator with the required keyFunc +// and signatureAlgorithm as well as custom options. +// This function has been added to provide an alternate function without the required issuer or audience parameters +// so they can be included in the opts parameter via WithExpectedClaims +// This function operates exactly like New with the exception of the two parameters issuer and audience and this function +// expects the inclusion of WithExpectedClaims with at least one valid expected claim. +// A valid expected claim would include an issuer and at least one audience +func NewValidator( + keyFunc func(context.Context) (interface{}, error), + signatureAlgorithm SignatureAlgorithm, + opts ...Option, +) (*Validator, error) { + if keyFunc == nil { + return nil, errors.New("keyFunc is required but was nil") } if _, ok := allowedSigningAlgorithms[signatureAlgorithm]; !ok { return nil, errors.New("unsupported signature algorithm") @@ -79,16 +130,26 @@ func New( v := &Validator{ keyFunc: keyFunc, signatureAlgorithm: signatureAlgorithm, - expectedClaims: jwt.Expected{ - Issuer: issuerURL, - Audience: audience, - }, + expectedClaims: make([]jwt.Expected, 0), } for _, opt := range opts { opt(v) } + if len(v.expectedClaims) == 0 { + return nil, errors.New("expected claims but none provided") + } + + for i, expected := range v.expectedClaims { + if expected.Issuer == "" { + return nil, fmt.Errorf("issuer url %d is required but was empty", i) + } + if len(expected.Audience) == 0 { + return nil, fmt.Errorf("audience %d is required but was empty", i) + } + } + return v, nil } @@ -134,38 +195,74 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte return validatedClaims, nil } -func validateClaimsWithLeeway(actualClaims jwt.Claims, expected jwt.Expected, leeway time.Duration) error { - expectedClaims := expected - expectedClaims.Time = time.Now() +func validateClaimsWithLeeway(actualClaims jwt.Claims, expectedIn []jwt.Expected, leeway time.Duration) error { + now := time.Now() + var currentError error + for _, expected := range expectedIn { + expectedClaims := expected + expectedClaims.Time = now - if actualClaims.Issuer != expectedClaims.Issuer { - return jwt.ErrInvalidIssuer - } + if actualClaims.Issuer != expectedClaims.Issuer { + currentError = createOrWrapError(currentError, jwt.ErrInvalidIssuer, actualClaims.Issuer, expectedClaims.Issuer) + continue + } - foundAudience := false - for _, value := range expectedClaims.Audience { - if actualClaims.Audience.Contains(value) { - foundAudience = true - break + foundAudience := false + for _, value := range expectedClaims.Audience { + if actualClaims.Audience.Contains(value) { + foundAudience = true + break + } + } + if !foundAudience { + currentError = createOrWrapError( + currentError, + jwt.ErrInvalidAudience, + strings.Join(actualClaims.Audience, ","), + strings.Join(expectedClaims.Audience, ","), + ) + continue } - } - if !foundAudience { - return jwt.ErrInvalidAudience - } - if actualClaims.NotBefore != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.NotBefore.Time()) { - return jwt.ErrNotValidYet - } + if actualClaims.NotBefore != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.NotBefore.Time()) { + return createOrWrapError( + currentError, + jwt.ErrNotValidYet, + actualClaims.NotBefore.Time().String(), + expectedClaims.Time.Add(leeway).String(), + ) + } - if actualClaims.Expiry != nil && expectedClaims.Time.Add(-leeway).After(actualClaims.Expiry.Time()) { - return jwt.ErrExpired + if actualClaims.Expiry != nil && expectedClaims.Time.Add(-leeway).After(actualClaims.Expiry.Time()) { + return createOrWrapError( + currentError, + jwt.ErrExpired, + actualClaims.Expiry.Time().String(), + expectedClaims.Time.Add(leeway).String(), + ) + } + + if actualClaims.IssuedAt != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.IssuedAt.Time()) { + return createOrWrapError( + currentError, + jwt.ErrIssuedInTheFuture, + actualClaims.IssuedAt.Time().String(), + expectedClaims.Time.Add(leeway).String(), + ) + } + + return nil } - if actualClaims.IssuedAt != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.IssuedAt.Time()) { - return jwt.ErrIssuedInTheFuture + return currentError +} + +func createOrWrapError(base, current error, actual, expected string) error { + if base == nil { + return current } - return nil + return errors.Join(base, fmt.Errorf("%v: %s vs %s", current, actual, expected)) } func validateSigningMethod(validAlg, tokenAlg string) error { diff --git a/validator/validator_test.go b/validator/validator_test.go index 08feeb14..84d986b2 100644 --- a/validator/validator_test.go +++ b/validator/validator_test.go @@ -234,7 +234,228 @@ func TestValidator_ValidateToken(t *testing.T) { } } -func TestNewValidator(t *testing.T) { +func TestNewValidator_ValidateToken(t *testing.T) { + const ( + issuer = "https://go-jwt-middleware.eu.auth0.com/" + audience = "https://go-jwt-middleware-api/" + subject = "1234567890" + issuerB = "https://go-jwt-middleware.us.auth0.com/" + audienceB = "https://go-jwt-middleware-api-b/" + subjectB = "0987654321" + ) + + testCases := []struct { + name string + token string + keyFunc func(context.Context) (interface{}, error) + algorithm SignatureAlgorithm + customClaims func() CustomClaims + expectedError error + expectedClaims *ValidatedClaims + }{ + { + name: "it successfully validates a token", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.-R2K2tZHDrgsEh9JNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + }, + }, + }, + { + name: "it successfully validates a token with custom claims", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJzY29wZSI6InJlYWQ6bWVzc2FnZXMifQ.oqtUZQ-Q8un4CPduUBdGVq5gXpQVIFT_QSQjkOXFT5I", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + customClaims: func() CustomClaims { + return &testClaims{} + }, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + }, + CustomClaims: &testClaims{ + Scope: "read:messages", + }, + }, + }, + { + name: "it throws an error when token has a different signing algorithm than the validator", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.-R2K2tZHDrgsEh9JNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: RS256, + expectedError: errors.New(`signing method is invalid: expected "RS256" signing algorithm but token specified "HS256"`), + }, + { + name: "it throws an error when it cannot parse the token", + token: "", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: errors.New("could not parse the token: go-jose/go-jose: compact JWS format must have three parts"), + }, + { + name: "it throws an error when it fails to fetch the keys from the key func", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.-R2K2tZHDrgsEh9JNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return nil, errors.New("key func error message") + }, + algorithm: HS256, + expectedError: errors.New("failed to deserialize token claims: error getting the keys from the key func: key func error message"), + }, + { + name: "it throws an error when it fails to deserialize the claims because the signature is invalid", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.vR2K2tZHDrgsEh9zNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: errors.New("failed to deserialize token claims: could not get token claims: go-jose/go-jose: error in cryptographic primitive"), + }, + { + name: "it throws an error when it fails to validate the registered claims", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIn0.VoIwDVmb--26wGrv93NmjNZYa4nrzjLw4JANgEjPI28", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: errors.New("go-jose/go-jose/jwt: validation failed, invalid audience claim (aud)"), + }, + { + name: "it throws an error when it fails to validate the custom claims", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJzY29wZSI6InJlYWQ6bWVzc2FnZXMifQ.oqtUZQ-Q8un4CPduUBdGVq5gXpQVIFT_QSQjkOXFT5I", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + customClaims: func() CustomClaims { + return &testClaims{ + ReturnError: errors.New("custom claims error message"), + } + }, + expectedError: errors.New("custom claims not validated: custom claims error message"), + }, + { + name: "it successfully validates a token even if customClaims() returns nil", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJzY29wZSI6InJlYWQ6bWVzc2FnZXMifQ.oqtUZQ-Q8un4CPduUBdGVq5gXpQVIFT_QSQjkOXFT5I", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + customClaims: func() CustomClaims { + return nil + }, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + }, + CustomClaims: nil, + }, + }, + { + name: "it successfully validates a token with exp, nbf and iat", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo5NjY3OTM3Njg2fQ.FKZogkm08gTfYfPU6eYu7OHCjJKnKGLiC0IfoIOPEhs", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + Expiry: 9667937686, + NotBefore: 1666939000, + IssuedAt: 1666937686, + }, + }, + }, + { + name: "it throws an error when token is not valid yet", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6OTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.yUizJ-zK_33tv1qBVvDKO0RuCWtvJ02UQKs8gBadgGY", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrNotValidYet), + }, + { + name: "it throws an error when token is expired", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo2Njc5Mzc2ODZ9.SKvz82VOXRi_sjvZWIsPG9vSWAXKKgVS4DkGZcwFKL8", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrExpired), + }, + { + name: "it throws an error when token is issued in the future", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjkxNjY2OTM3Njg2LCJuYmYiOjE2NjY5MzkwMDAsImV4cCI6ODY2NzkzNzY4Nn0.ieFV7XNJxiJyw8ARq9yHw-01Oi02e3P2skZO10ypxL8", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrIssuedInTheFuture), + }, + { + name: "it throws an error when token issuer is invalid", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2hhY2tlZC1qd3QtbWlkZGxld2FyZS5ldS5hdXRoMC5jb20vIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6WyJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLWFwaS8iXSwiaWF0Ijo5MTY2NjkzNzY4NiwibmJmIjoxNjY2OTM5MDAwLCJleHAiOjg2Njc5Mzc2ODZ9.b5gXNrUNfd_jyCWZF-6IPK_UFfvTr9wBQk9_QgRQ8rA", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrInvalidIssuer), + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + validator, err := NewValidator( + testCase.keyFunc, + testCase.algorithm, + WithCustomClaims(testCase.customClaims), + WithAllowedClockSkew(time.Second), + WithExpectedClaims(jwt.Expected{ + Issuer: issuer, + Audience: []string{audience, "another-audience"}, + }, jwt.Expected{ + Issuer: issuerB, + Audience: []string{audienceB, "another-audienceb"}, + }), + ) + require.NoError(t, err) + + tokenClaims, err := validator.ValidateToken(context.Background(), testCase.token) + if testCase.expectedError != nil { + assert.ErrorContains(t, err, testCase.expectedError.Error()) + assert.Nil(t, tokenClaims) + } else { + require.NoError(t, err) + assert.Exactly(t, testCase.expectedClaims, tokenClaims) + } + }) + } +} + +func TestNew(t *testing.T) { const ( issuer = "https://go-jwt-middleware.eu.auth0.com/" audience = "https://go-jwt-middleware-api/" @@ -260,12 +481,12 @@ func TestNewValidator(t *testing.T) { assert.EqualError(t, err, "unsupported signature algorithm") }) - t.Run("it throws an error when the issuerURL is empty", func(t *testing.T) { + t.Run("it throws an error when the issuerURL is empty and no expectedClaims option", func(t *testing.T) { _, err := New(keyFunc, algorithm, "", []string{audience}) assert.EqualError(t, err, "issuer url is required but was empty") }) - t.Run("it throws an error when the audience is nil", func(t *testing.T) { + t.Run("it throws an error when the audience is nil if no expectedClaims option included", func(t *testing.T) { _, err := New(keyFunc, algorithm, issuer, nil) assert.EqualError(t, err, "audience is required but was empty") }) @@ -274,4 +495,81 @@ func TestNewValidator(t *testing.T) { _, err := New(keyFunc, algorithm, issuer, []string{}) assert.EqualError(t, err, "audience is required but was empty") }) + + t.Run("it throws an error when the issuerURL is empty and an expectedClaims option with only an audience", func(t *testing.T) { + _, err := New(keyFunc, algorithm, "", []string{}, WithExpectedClaims(jwt.Expected{Audience: []string{audience}})) + assert.EqualError(t, err, "issuer url 0 is required but was empty") + }) + + t.Run("it throws an error when the audience is empty and the expectedClaims are missing an audience", func(t *testing.T) { + _, err := New(keyFunc, algorithm, issuer, []string{}, WithExpectedClaims(jwt.Expected{Issuer: issuer})) + assert.EqualError(t, err, "audience 0 is required but was empty") + }) + + t.Run("it throws no error when the issuerURL is empty but expectedClaims option included", func(t *testing.T) { + _, err := New(keyFunc, algorithm, "", []string{audience}, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}})) + assert.NoError(t, err, "no error was expected") + }) + + t.Run("it throws no error when the audience is nil but expectedClaims option included", func(t *testing.T) { + _, err := New(keyFunc, algorithm, issuer, nil, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}})) + assert.NoError(t, err, "no error was expected") + }) +} + +func TestNewValidator(t *testing.T) { + const ( + issuer = "https://go-jwt-middleware.eu.auth0.com/" + audience = "https://go-jwt-middleware-api/" + algorithm = HS256 + ) + + var keyFunc = func(context.Context) (interface{}, error) { + return []byte("secret"), nil + } + + t.Run("it throws an error when the keyFunc is nil", func(t *testing.T) { + _, err := NewValidator(nil, algorithm) + assert.EqualError(t, err, "keyFunc is required but was nil") + }) + + t.Run("it throws an error when the signature algorithm is empty", func(t *testing.T) { + _, err := NewValidator(keyFunc, "") + assert.EqualError(t, err, "unsupported signature algorithm") + }) + + t.Run("it throws an error when the signature algorithm is unsupported", func(t *testing.T) { + _, err := NewValidator(keyFunc, "none") + assert.EqualError(t, err, "unsupported signature algorithm") + }) + + t.Run("it throws an error when there are no expected claims", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm) + assert.EqualError(t, err, "expected claims but none provided") + }) + + t.Run("it throws an error when expectedClaims option with only an audience", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Audience: []string{audience}})) + assert.EqualError(t, err, "issuer url 0 is required but was empty") + }) + + t.Run("it throws an error when expectedClaims option with only an audience in the second jwt.Expected", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}}, jwt.Expected{Audience: []string{audience}})) + assert.EqualError(t, err, "issuer url 1 is required but was empty") + }) + + t.Run("it throws an error when the audience is empty and the expectedClaims are missing an audience", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer})) + assert.EqualError(t, err, "audience 0 is required but was empty") + }) + + t.Run("it throws an error when the audience is empty and the expectedClaims are missing an audience in the second jwt.Expected", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}}, jwt.Expected{Issuer: issuer})) + assert.EqualError(t, err, "audience 1 is required but was empty") + }) + + t.Run("it throws no error when input is correct", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}})) + assert.NoError(t, err, "no error was expected") + }) }