Skip to content

Commit

Permalink
feat: support multiple issuer:audience combinations by introducing an…
Browse files Browse the repository at this point in the history
… option for the expectedClaims. WithExpectedClaims can be called with multiple jwt.Expected parameters to allow different Issuer:Audience combinations to validate tokens

feat: support multiple issuers in a provider using WithAdditionalIssuers option

Every effort has been made to ensure backwards compatibility. Some error messages will be different due to the wrapping of errors when multiple jwt.Expected are set. When validating the jwt, if an error is encountered, instead of returning immediately, the current error is wrapped. This is good and bad. Good because all verification failure causes are captured in a single wrapped error; Bad because all verification failure causes are captured in a single monolithic wrapped error. Unwrapping the error can be tedious if many jwt.Expected are included. There is likely a better way but this suits my purposes.

A few more test cases will likely be needed in order to achieve true confidence in this change
  • Loading branch information
cmmoran committed Nov 9, 2024
1 parent 4f02637 commit 0463741
Show file tree
Hide file tree
Showing 11 changed files with 606 additions and 53 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
"log"
"net/http"

"github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/validator"
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
)
Expand Down
54 changes: 52 additions & 2 deletions examples/gin-example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,29 @@ import (
// "username": "user123",
// "shouldReject": true
// }
//
// You can also try out the /multiple endpoint. This endpoint accepts tokens signed by multiple issuers. Try the
// token below which has a different issuer:
//
// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1tdWx0aXBsZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtbXVsdGlwbGUtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.9zV_bY1wAmQlMCPlXOppx1Y9_z_T_wNng9-yfQk4I0c
//
// which is signed with 'secret' and has the data:
//
// {
// "iss": "go-jwt-middleware-multiple-example",
// "aud": "audience-multiple-example",
// "sub": "1234567890",
// "name": "John Doe",
// "iat": 1516239022,
// "username": "user123"
// }
//
// You can also try the previous tokens with the /multiple endpoint. The first token will be valid the second will fail because
// the custom validator rejects it (shouldReject: true)

func main() {
router := gin.Default()

router.GET("/", checkJWT(), func(ctx *gin.Context) {
claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims)
if !ok {
Expand All @@ -52,7 +72,37 @@ func main() {
return
}

customClaims, ok := claims.CustomClaims.(*CustomClaimsExample)
localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample)
if !ok {
ctx.AbortWithStatusJSON(
http.StatusInternalServerError,
map[string]string{"message": "Failed to cast custom JWT claims to specific type."},
)
return
}

if len(localCustomClaims.Username) == 0 {
ctx.AbortWithStatusJSON(
http.StatusBadRequest,
map[string]string{"message": "Username in JWT claims was empty."},
)
return
}

ctx.JSON(http.StatusOK, claims)
})

router.GET("/multiple", checkJWTMultiple(), func(ctx *gin.Context) {
claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims)
if !ok {
ctx.AbortWithStatusJSON(
http.StatusInternalServerError,
map[string]string{"message": "Failed to get validated JWT claims."},
)
return
}

localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample)
if !ok {
ctx.AbortWithStatusJSON(
http.StatusInternalServerError,
Expand All @@ -61,7 +111,7 @@ func main() {
return
}

if len(customClaims.Username) == 0 {
if len(localCustomClaims.Username) == 0 {
ctx.AbortWithStatusJSON(
http.StatusBadRequest,
map[string]string{"message": "Username in JWT claims was empty."},
Expand Down
54 changes: 52 additions & 2 deletions examples/gin-example/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"gopkg.in/go-jose/go-jose.v2/jwt"
"log"
"net/http"
"time"
Expand All @@ -16,10 +17,12 @@ var (
signingKey = []byte("secret")

// The issuer of our token.
issuer = "go-jwt-middleware-example"
issuer = "go-jwt-middleware-example"
issuerTwo = "go-jwt-middleware-multiple-example"

// The audience of our token.
audience = []string{"audience-example"}
audience = []string{"audience-example"}
audienceTwo = []string{"audience-multiple-example"}

// Our token must be signed using this data.
keyFunc = func(ctx context.Context) (interface{}, error) {
Expand Down Expand Up @@ -76,3 +79,50 @@ func checkJWT() gin.HandlerFunc {
}
}
}

func checkJWTMultiple() gin.HandlerFunc {
// Set up the validator.
jwtValidator, err := validator.NewValidator(
keyFunc,
validator.HS256,
validator.WithCustomClaims(customClaims),
validator.WithAllowedClockSkew(30*time.Second),
validator.WithExpectedClaims(jwt.Expected{
Issuer: issuer,
Audience: audience,
}, jwt.Expected{
Issuer: issuerTwo,
Audience: audienceTwo,
}),
)
if err != nil {
log.Fatalf("failed to set up the validator: %v", err)
}

errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
log.Printf("Encountered error while validating JWT: %v", err)
}

middleware := jwtmiddleware.New(
jwtValidator.ValidateToken,
jwtmiddleware.WithErrorHandler(errorHandler),
)

return func(ctx *gin.Context) {
encounteredError := true
var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
encounteredError = false
ctx.Request = r
ctx.Next()
}

middleware.CheckJWT(handler).ServeHTTP(ctx.Writer, ctx.Request)

if encounteredError {
ctx.AbortWithStatusJSON(
http.StatusUnauthorized,
map[string]string{"message": "JWT is invalid."},
)
}
}
}
3 changes: 1 addition & 2 deletions examples/http-example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ import (
"net/http"
"time"

"github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/validator"
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/validator"
)

var (
Expand Down
3 changes: 1 addition & 2 deletions examples/http-jwks-example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ import (
"net/url"
"time"

"github.com/auth0/go-jwt-middleware/v2"
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
"github.com/auth0/go-jwt-middleware/v2/jwks"
"github.com/auth0/go-jwt-middleware/v2/validator"
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
)

var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
4 changes: 2 additions & 2 deletions extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) {

authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
return "", errors.New("authorization header format must be Bearer {token}")
}

return authHeaderParts[1], nil
Expand All @@ -34,7 +34,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) {
func CookieTokenExtractor(cookieName string) TokenExtractor {
return func(r *http.Request) (string, error) {
cookie, err := r.Cookie(cookieName)
if err == http.ErrNoCookie {
if errors.Is(err, http.ErrNoCookie) {
return "", nil // No cookie, then no JWT, so no error.
}

Expand Down
2 changes: 1 addition & 1 deletion extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func Test_AuthHeaderTokenExtractor(t *testing.T) {
"Authorization": []string{"i-am-a-token"},
},
},
wantError: "Authorization header format must be Bearer {token}",
wantError: "authorization header format must be Bearer {token}",
},
}

Expand Down
61 changes: 54 additions & 7 deletions jwks/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import (
// getting and caching JWKS which can help reduce request time and potential
// rate limiting from your provider.
type Provider struct {
IssuerURL *url.URL // Required.
CustomJWKSURI *url.URL // Optional.
Client *http.Client
IssuerURL *url.URL // Required.
CustomJWKSURI *url.URL // Optional.
AdditionalProviders []Provider // Optional
Client *http.Client
}

// ProviderOption is how options for the Provider are set up.
Expand All @@ -32,14 +33,24 @@ type ProviderOption func(*Provider)
// NewProvider builds and returns a new *Provider.
func NewProvider(issuerURL *url.URL, opts ...ProviderOption) *Provider {
p := &Provider{
IssuerURL: issuerURL,
Client: &http.Client{},
Client: &http.Client{},
AdditionalProviders: make([]Provider, 0),
}

if issuerURL != nil {
p.IssuerURL = issuerURL
}

for _, opt := range opts {
opt(p)
}

for _, provider := range p.AdditionalProviders {
if provider.Client == nil {
provider.Client = p.Client
}
}

return p
}

Expand All @@ -56,13 +67,47 @@ func WithCustomJWKSURI(jwksURI *url.URL) ProviderOption {
func WithCustomClient(c *http.Client) ProviderOption {
return func(p *Provider) {
p.Client = c
for _, provider := range p.AdditionalProviders {
provider.Client = c
}
}
}

// WithAdditionalProviders allows validation with mutliple IssuerURLs if desired. If multiple issuers are specified,
// a jwt may be signed by any of them and be considered valid
func WithAdditionalProviders(issuerURL *url.URL, customJWKSURI *url.URL) ProviderOption {
return func(p *Provider) {
p.AdditionalProviders = append(p.AdditionalProviders, Provider{
IssuerURL: issuerURL,
CustomJWKSURI: customJWKSURI,
Client: p.Client,
})
}
}

// KeyFunc adheres to the keyFunc signature that the Validator requires.
// While it returns an interface to adhere to keyFunc, as long as the
// error is nil the type will be *jose.JSONWebKeySet.
func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
rawJwks, err := p.keyFunc(ctx)

if len(p.AdditionalProviders) == 0 {
return rawJwks, err
} else {
var jwks *jose.JSONWebKeySet
jwks = rawJwks.(*jose.JSONWebKeySet)
for _, provider := range p.AdditionalProviders {
if rawJwks, err = provider.keyFunc(ctx); err != nil {
continue
} else {
jwks.Keys = append(jwks.Keys, rawJwks.(*jose.JSONWebKeySet).Keys...)
}
}
return jwks, err
}
}

func (p *Provider) keyFunc(ctx context.Context) (interface{}, error) {
jwksURI := p.CustomJWKSURI
if jwksURI == nil {
wkEndpoints, err := oidc.GetWellKnownEndpointsFromIssuerURL(ctx, p.Client, *p.IssuerURL)
Expand All @@ -85,10 +130,12 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
if err != nil {
return nil, err
}
defer response.Body.Close()
defer func() {
_ = response.Body.Close()
}()

var jwks jose.JSONWebKeySet
if err := json.NewDecoder(response.Body).Decode(&jwks); err != nil {
if err = json.NewDecoder(response.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("could not decode jwks: %w", err)
}

Expand Down
14 changes: 14 additions & 0 deletions validator/option.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package validator

import (
"gopkg.in/go-jose/go-jose.v2/jwt"
"time"
)

Expand All @@ -26,3 +27,16 @@ func WithCustomClaims(f func() CustomClaims) Option {
v.customClaims = f
}
}

// WithExpectedClaims allows fine-grained customization of the expected claims
func WithExpectedClaims(expectedClaims ...jwt.Expected) Option {
return func(v *Validator) {
if len(expectedClaims) == 0 {
return
}
if v.expectedClaims == nil {
v.expectedClaims = make([]jwt.Expected, 0)
}
v.expectedClaims = append(v.expectedClaims, expectedClaims...)
}
}
Loading

0 comments on commit 0463741

Please sign in to comment.