Skip to content

Commit

Permalink
Merge pull request #163 from checkr/zz/correctly-handle-401-jwt
Browse files Browse the repository at this point in the history
Fix 401 jwt middleware handling
  • Loading branch information
zhouzhuojie authored Sep 13, 2018
2 parents 42ae967 + 323e7ea commit 516dd45
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 28 deletions.
17 changes: 9 additions & 8 deletions pkg/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,15 @@ var Config = struct {
Note:
If the access_token is present in both the header and cookie only the latest will be used
*/
JWTAuthEnabled bool `env:"FLAGR_JWT_AUTH_ENABLED" envDefault:"false"`
JWTAuthDebug bool `env:"FLAGR_JWT_AUTH_DEBUG" envDefault:"false"`
JWTAuthWhitelistPaths []string `env:"FLAGR_JWT_AUTH_WHITELIST_PATHS" envDefault:"/api/v1/evaluation" envSeparator:","`
JWTAuthCookieTokenName string `env:"FLAGR_JWT_AUTH_COOKIE_TOKEN_NAME" envDefault:"access_token"`
JWTAuthSecret string `env:"FLAGR_JWT_AUTH_SECRET" envDefault:""`
JWTAuthNoTokenStatusCode int `env:"FLAGR_JWT_AUTH_NO_TOKEN_STATUS_CODE" envDefault:"307"` // "307" or "401"
JWTAuthNoTokenRedirectURL string `env:"FLAGR_JWT_AUTH_NO_TOKEN_REDIRECT_URL" envDefault:""`
JWTAuthUserProperty string `env:"FLAGR_JWT_AUTH_USER_PROPERTY" envDefault:"flagr_user"`
JWTAuthEnabled bool `env:"FLAGR_JWT_AUTH_ENABLED" envDefault:"false"`
JWTAuthDebug bool `env:"FLAGR_JWT_AUTH_DEBUG" envDefault:"false"`
JWTAuthPrefixWhitelistPaths []string `env:"FLAGR_JWT_AUTH_WHITELIST_PATHS" envDefault:"/api/v1/evaluation,/static" envSeparator:","`
JWTAuthExactWhitelistPaths []string `env:"FLAGR_JWT_AUTH_EXACT_WHITELIST_PATHS" envDefault:",/" envSeparator:","`
JWTAuthCookieTokenName string `env:"FLAGR_JWT_AUTH_COOKIE_TOKEN_NAME" envDefault:"access_token"`
JWTAuthSecret string `env:"FLAGR_JWT_AUTH_SECRET" envDefault:""`
JWTAuthNoTokenStatusCode int `env:"FLAGR_JWT_AUTH_NO_TOKEN_STATUS_CODE" envDefault:"307"` // "307" or "401"
JWTAuthNoTokenRedirectURL string `env:"FLAGR_JWT_AUTH_NO_TOKEN_REDIRECT_URL" envDefault:""`
JWTAuthUserProperty string `env:"FLAGR_JWT_AUTH_USER_PROPERTY" envDefault:"flagr_user"`

// "HS256" and "RS256" supported
JWTAuthSigningMethod string `env:"FLAGR_JWT_AUTH_SIGNING_METHOD" envDefault:"HS256"`
Expand Down
33 changes: 26 additions & 7 deletions pkg/config/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ func setupJWTAuthMiddleware() *auth {
}

return &auth{
WhitelistPaths: Config.JWTAuthWhitelistPaths,
PrefixWhitelistPaths: Config.JWTAuthPrefixWhitelistPaths,
ExactWhitelistPaths: Config.JWTAuthExactWhitelistPaths,
JWTMiddleware: jwtmiddleware.New(jwtmiddleware.Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
return validationKey, errParsingKey
Expand Down Expand Up @@ -116,18 +117,36 @@ func jwtErrorHandler(w http.ResponseWriter, r *http.Request, err string) {
}

type auth struct {
WhitelistPaths []string
JWTMiddleware *jwtmiddleware.JWTMiddleware
PrefixWhitelistPaths []string
ExactWhitelistPaths []string
JWTMiddleware *jwtmiddleware.JWTMiddleware
}

func (a *auth) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
func (a *auth) whitelist(req *http.Request) bool {
path := req.URL.Path
for _, p := range a.WhitelistPaths {

// If we set to 401 unauthorized, let the client handles the 401 itself
if Config.JWTAuthNoTokenStatusCode == http.StatusUnauthorized {
for _, p := range a.ExactWhitelistPaths {
if p == path {
return true
}
}
}

for _, p := range a.PrefixWhitelistPaths {
if p != "" && strings.HasPrefix(path, p) {
next(w, req)
return
return true
}
}
return false
}

func (a *auth) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
if a.whitelist(req) {
next(w, req)
return
}
a.JWTMiddleware.HandlerWithNext(w, req, next)
}

Expand Down
63 changes: 62 additions & 1 deletion pkg/config/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestAuthMiddleware(t *testing.T) {

res := httptest.NewRecorder()
res.Body = new(bytes.Buffer)
req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:18000%s", Config.JWTAuthWhitelistPaths[0]), nil)
req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:18000%s", Config.JWTAuthPrefixWhitelistPaths[0]), nil)
hh.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
})
Expand Down Expand Up @@ -233,6 +233,67 @@ o2kQ+X5xK9cipRgEKwIDAQAB
})
}

func TestAuthMiddlewareWithUnauthorized(t *testing.T) {
h := &okHandler{}

t.Run("it will return 401 if no cookie passed", func(t *testing.T) {
Config.JWTAuthEnabled = true
Config.JWTAuthNoTokenStatusCode = http.StatusUnauthorized
defer func() {
Config.JWTAuthEnabled = false
Config.JWTAuthNoTokenStatusCode = http.StatusTemporaryRedirect
}()

hh := SetupGlobalMiddleware(h)
res := httptest.NewRecorder()
res.Body = new(bytes.Buffer)
req, _ := http.NewRequest("GET", "http://localhost:18000/api/v1/flags", nil)
hh.ServeHTTP(res, req)
assert.Equal(t, http.StatusUnauthorized, res.Code)
})

t.Run("it will return 200 if cookie passed", func(t *testing.T) {
Config.JWTAuthEnabled = true
Config.JWTAuthNoTokenStatusCode = http.StatusUnauthorized
defer func() {
Config.JWTAuthEnabled = false
Config.JWTAuthNoTokenStatusCode = http.StatusTemporaryRedirect
}()

hh := SetupGlobalMiddleware(h)
res := httptest.NewRecorder()
res.Body = new(bytes.Buffer)
req, _ := http.NewRequest("GET", "http://localhost:18000/api/v1/flags", nil)
req.AddCookie(&http.Cookie{
Name: "access_token",
Value: validHS256JWTToken,
})
hh.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
})

t.Run("it will return 200 for some paths", func(t *testing.T) {
Config.JWTAuthEnabled = true
Config.JWTAuthNoTokenStatusCode = http.StatusUnauthorized
defer func() {
Config.JWTAuthEnabled = false
Config.JWTAuthNoTokenStatusCode = http.StatusTemporaryRedirect
}()

testPaths := []string{"/", "", "/#", "/#/", "/static", "/static/"}
for _, path := range testPaths {
t.Run(fmt.Sprintf("path: %s", path), func(t *testing.T) {
hh := SetupGlobalMiddleware(h)
res := httptest.NewRecorder()
res.Body = new(bytes.Buffer)
req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:18000%s", path), nil)
hh.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
})
}
})
}

func TestStatsMiddleware(t *testing.T) {
t.Run("it will setup statsd if statsd is enabled", func(t *testing.T) {
Config.StatsdEnabled = true
Expand Down
13 changes: 1 addition & 12 deletions pkg/handler/eval_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,14 @@ import (
"github.com/stretchr/testify/assert"
)

func TestGetEvalCacheStart(t *testing.T) {
db := entity.PopulateTestDB(entity.GenFixtureFlag())
defer db.Close()
defer gostub.StubFunc(&getDB, db).Reset()

ec := GetEvalCache()
assert.NotPanics(t, func() {
ec.Start()
})
}

func TestGetByFlagKeyOrID(t *testing.T) {
fixtureFlag := entity.GenFixtureFlag()
db := entity.PopulateTestDB(fixtureFlag)
defer db.Close()
defer gostub.StubFunc(&getDB, db).Reset()

ec := GetEvalCache()
ec.Start()
ec.reloadMapCache()
f := ec.GetByFlagKeyOrID(fixtureFlag.ID)
assert.Equal(t, f.ID, fixtureFlag.ID)
}

0 comments on commit 516dd45

Please sign in to comment.