From 323e7eaed357dc9cb22ed31d2937d905b17ec685 Mon Sep 17 00:00:00 2001 From: zhuojie Date: Thu, 13 Sep 2018 10:20:05 -0700 Subject: [PATCH] Fix 401 jwt middleware handling --- pkg/config/env.go | 17 ++++----- pkg/config/middleware.go | 33 ++++++++++++++---- pkg/config/middleware_test.go | 63 +++++++++++++++++++++++++++++++++- pkg/handler/eval_cache_test.go | 13 +------ 4 files changed, 98 insertions(+), 28 deletions(-) diff --git a/pkg/config/env.go b/pkg/config/env.go index 096762c1..eb7ed4ea 100644 --- a/pkg/config/env.go +++ b/pkg/config/env.go @@ -107,14 +107,15 @@ var Config = struct { Note: If the access_token is present in both the header and cookie only the latest will be used */ - JWTAuthEnabled bool `env:"FLAGR_JWT_AUTH_ENABLED" envDefault:"false"` - JWTAuthDebug bool `env:"FLAGR_JWT_AUTH_DEBUG" envDefault:"false"` - JWTAuthWhitelistPaths []string `env:"FLAGR_JWT_AUTH_WHITELIST_PATHS" envDefault:"/api/v1/evaluation" envSeparator:","` - JWTAuthCookieTokenName string `env:"FLAGR_JWT_AUTH_COOKIE_TOKEN_NAME" envDefault:"access_token"` - JWTAuthSecret string `env:"FLAGR_JWT_AUTH_SECRET" envDefault:""` - JWTAuthNoTokenStatusCode int `env:"FLAGR_JWT_AUTH_NO_TOKEN_STATUS_CODE" envDefault:"307"` // "307" or "401" - JWTAuthNoTokenRedirectURL string `env:"FLAGR_JWT_AUTH_NO_TOKEN_REDIRECT_URL" envDefault:""` - JWTAuthUserProperty string `env:"FLAGR_JWT_AUTH_USER_PROPERTY" envDefault:"flagr_user"` + JWTAuthEnabled bool `env:"FLAGR_JWT_AUTH_ENABLED" envDefault:"false"` + JWTAuthDebug bool `env:"FLAGR_JWT_AUTH_DEBUG" envDefault:"false"` + JWTAuthPrefixWhitelistPaths []string `env:"FLAGR_JWT_AUTH_WHITELIST_PATHS" envDefault:"/api/v1/evaluation,/static" envSeparator:","` + JWTAuthExactWhitelistPaths []string `env:"FLAGR_JWT_AUTH_EXACT_WHITELIST_PATHS" envDefault:",/" envSeparator:","` + JWTAuthCookieTokenName string `env:"FLAGR_JWT_AUTH_COOKIE_TOKEN_NAME" envDefault:"access_token"` + JWTAuthSecret string `env:"FLAGR_JWT_AUTH_SECRET" envDefault:""` + JWTAuthNoTokenStatusCode int `env:"FLAGR_JWT_AUTH_NO_TOKEN_STATUS_CODE" envDefault:"307"` // "307" or "401" + JWTAuthNoTokenRedirectURL string `env:"FLAGR_JWT_AUTH_NO_TOKEN_REDIRECT_URL" envDefault:""` + JWTAuthUserProperty string `env:"FLAGR_JWT_AUTH_USER_PROPERTY" envDefault:"flagr_user"` // "HS256" and "RS256" supported JWTAuthSigningMethod string `env:"FLAGR_JWT_AUTH_SIGNING_METHOD" envDefault:"HS256"` diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 42854581..a0ef52c6 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -80,7 +80,8 @@ func setupJWTAuthMiddleware() *auth { } return &auth{ - WhitelistPaths: Config.JWTAuthWhitelistPaths, + PrefixWhitelistPaths: Config.JWTAuthPrefixWhitelistPaths, + ExactWhitelistPaths: Config.JWTAuthExactWhitelistPaths, JWTMiddleware: jwtmiddleware.New(jwtmiddleware.Options{ ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { return validationKey, errParsingKey @@ -116,18 +117,36 @@ func jwtErrorHandler(w http.ResponseWriter, r *http.Request, err string) { } type auth struct { - WhitelistPaths []string - JWTMiddleware *jwtmiddleware.JWTMiddleware + PrefixWhitelistPaths []string + ExactWhitelistPaths []string + JWTMiddleware *jwtmiddleware.JWTMiddleware } -func (a *auth) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) { +func (a *auth) whitelist(req *http.Request) bool { path := req.URL.Path - for _, p := range a.WhitelistPaths { + + // If we set to 401 unauthorized, let the client handles the 401 itself + if Config.JWTAuthNoTokenStatusCode == http.StatusUnauthorized { + for _, p := range a.ExactWhitelistPaths { + if p == path { + return true + } + } + } + + for _, p := range a.PrefixWhitelistPaths { if p != "" && strings.HasPrefix(path, p) { - next(w, req) - return + return true } } + return false +} + +func (a *auth) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) { + if a.whitelist(req) { + next(w, req) + return + } a.JWTMiddleware.HandlerWithNext(w, req, next) } diff --git a/pkg/config/middleware_test.go b/pkg/config/middleware_test.go index b1365be3..db75347b 100644 --- a/pkg/config/middleware_test.go +++ b/pkg/config/middleware_test.go @@ -113,7 +113,7 @@ func TestAuthMiddleware(t *testing.T) { res := httptest.NewRecorder() res.Body = new(bytes.Buffer) - req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:18000%s", Config.JWTAuthWhitelistPaths[0]), nil) + req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:18000%s", Config.JWTAuthPrefixWhitelistPaths[0]), nil) hh.ServeHTTP(res, req) assert.Equal(t, http.StatusOK, res.Code) }) @@ -233,6 +233,67 @@ o2kQ+X5xK9cipRgEKwIDAQAB }) } +func TestAuthMiddlewareWithUnauthorized(t *testing.T) { + h := &okHandler{} + + t.Run("it will return 401 if no cookie passed", func(t *testing.T) { + Config.JWTAuthEnabled = true + Config.JWTAuthNoTokenStatusCode = http.StatusUnauthorized + defer func() { + Config.JWTAuthEnabled = false + Config.JWTAuthNoTokenStatusCode = http.StatusTemporaryRedirect + }() + + hh := SetupGlobalMiddleware(h) + res := httptest.NewRecorder() + res.Body = new(bytes.Buffer) + req, _ := http.NewRequest("GET", "http://localhost:18000/api/v1/flags", nil) + hh.ServeHTTP(res, req) + assert.Equal(t, http.StatusUnauthorized, res.Code) + }) + + t.Run("it will return 200 if cookie passed", func(t *testing.T) { + Config.JWTAuthEnabled = true + Config.JWTAuthNoTokenStatusCode = http.StatusUnauthorized + defer func() { + Config.JWTAuthEnabled = false + Config.JWTAuthNoTokenStatusCode = http.StatusTemporaryRedirect + }() + + hh := SetupGlobalMiddleware(h) + res := httptest.NewRecorder() + res.Body = new(bytes.Buffer) + req, _ := http.NewRequest("GET", "http://localhost:18000/api/v1/flags", nil) + req.AddCookie(&http.Cookie{ + Name: "access_token", + Value: validHS256JWTToken, + }) + hh.ServeHTTP(res, req) + assert.Equal(t, http.StatusOK, res.Code) + }) + + t.Run("it will return 200 for some paths", func(t *testing.T) { + Config.JWTAuthEnabled = true + Config.JWTAuthNoTokenStatusCode = http.StatusUnauthorized + defer func() { + Config.JWTAuthEnabled = false + Config.JWTAuthNoTokenStatusCode = http.StatusTemporaryRedirect + }() + + testPaths := []string{"/", "", "/#", "/#/", "/static", "/static/"} + for _, path := range testPaths { + t.Run(fmt.Sprintf("path: %s", path), func(t *testing.T) { + hh := SetupGlobalMiddleware(h) + res := httptest.NewRecorder() + res.Body = new(bytes.Buffer) + req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:18000%s", path), nil) + hh.ServeHTTP(res, req) + assert.Equal(t, http.StatusOK, res.Code) + }) + } + }) +} + func TestStatsMiddleware(t *testing.T) { t.Run("it will setup statsd if statsd is enabled", func(t *testing.T) { Config.StatsdEnabled = true diff --git a/pkg/handler/eval_cache_test.go b/pkg/handler/eval_cache_test.go index c9361bc0..29885fb7 100644 --- a/pkg/handler/eval_cache_test.go +++ b/pkg/handler/eval_cache_test.go @@ -9,17 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestGetEvalCacheStart(t *testing.T) { - db := entity.PopulateTestDB(entity.GenFixtureFlag()) - defer db.Close() - defer gostub.StubFunc(&getDB, db).Reset() - - ec := GetEvalCache() - assert.NotPanics(t, func() { - ec.Start() - }) -} - func TestGetByFlagKeyOrID(t *testing.T) { fixtureFlag := entity.GenFixtureFlag() db := entity.PopulateTestDB(fixtureFlag) @@ -27,7 +16,7 @@ func TestGetByFlagKeyOrID(t *testing.T) { defer gostub.StubFunc(&getDB, db).Reset() ec := GetEvalCache() - ec.Start() + ec.reloadMapCache() f := ec.GetByFlagKeyOrID(fixtureFlag.ID) assert.Equal(t, f.ID, fixtureFlag.ID) }