Skip to content

Commit

Permalink
updated and fixed test cases for jwt middleware + sonar and linting f…
Browse files Browse the repository at this point in the history
…ixes
  • Loading branch information
himanshu-allen committed Oct 15, 2024
1 parent 4cb2f0c commit 380b101
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 25 deletions.
6 changes: 1 addition & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,13 @@ ENV FLAGR_JWT_AUTH_DEBUG=true
ENV FLAGR_JWT_AUTH_WHITELIST_PATHS="/api/v1/health,/api/v1/evaluation,/login,/callback,/static,/favicon.ico,/flags"
ENV FLAGR_JWT_AUTH_EXACT_WHITELIST_PATHS=",/,/login,/callback"
ENV FLAGR_JWT_AUTH_COOKIE_TOKEN_NAME="access_token"
ENV FLAGR_JWT_AUTH_SECRET="01d3024af90f4100c22d20eb294eee46e8fa286b53a4b08aa16cadf7d4b70b0935778af1228caf115caeb62873f789b7af6f2e6591272cec784d6fc68a91e9f8538a223b8aa622b594cfac2a01ef15fb583d5adadea3174be6cc91d0db638574997334df095427b3319a3937d10c07adc2e4a95669bd1fd75807dd02bca06432"
#ENV FLAGR_JWT_AUTH_SECRET="<your-jwt-secret>"
ENV FLAGR_JWT_AUTH_NO_TOKEN_STATUS_CODE=307
ENV FLAGR_JWT_AUTH_NO_TOKEN_REDIRECT_URL="http://localhost:3000/login"
ENV FLAGR_JWT_AUTH_USER_PROPERTY=flagr_user
ENV FLAGR_JWT_AUTH_USER_CLAIM=uid
ENV FLAGR_JWT_AUTH_SIGNING_METHOD=HS256

# ENV FLAGR_BASIC_AUTH_ENABLED=true
# ENV FLAGR_BASIC_AUTH_USERNAME=admin
# ENV FLAGR_BASIC_AUTH_PASSWORD=password

COPY --from=npm_builder /go/src/github.com/openflagr/flagr/browser/flagr-ui/dist ./browser/flagr-ui/dist

RUN addgroup -S appgroup && adduser -S appuser -G appgroup
Expand Down
8 changes: 6 additions & 2 deletions pkg/config/jwtmiddleware/jwt_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"context"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
"log"
"net/http"
"strings"

"github.com/golang-jwt/jwt/v5"
)

// A function called whenever an error is encountered
Expand All @@ -20,6 +21,9 @@ type errorHandler func(w http.ResponseWriter, r *http.Request, err string)
// be treated as an error. An empty string should be returned in that case.
type TokenExtractor func(r *http.Request) (string, error)

// Define a custom type for the key
type contextKey string

// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
Expand Down Expand Up @@ -229,7 +233,7 @@ func (m *JWTMiddleware) CheckJWT(w http.ResponseWriter, r *http.Request) error {

// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, parsedToken))
newRequest := r.WithContext(context.WithValue(r.Context(), contextKey(m.Options.UserProperty), parsedToken))
// Update the current request with the new context information.
*r = *newRequest
return nil
Expand Down
135 changes: 126 additions & 9 deletions pkg/config/jwtmiddleware/jwt_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package jwtmiddleware
import (
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -20,7 +20,7 @@ import (
const defaultAuthorizationHeaderName = "Authorization"

// userPropertyName is the property name that will be set in the request context
const userPropertyName = "custom-user-property"
const userPropertyName = "user"

// the bytes read from the keys/sample-key file
// private key generated with http://kjur.github.io/jsjws/tool_jwt.html
Expand Down Expand Up @@ -56,7 +56,7 @@ func TestAuthenticatedRequest(t *testing.T) {
var expectedAlgorithm jwt.SigningMethod = nil
w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm)
So(w.Code, ShouldEqual, http.StatusOK)
responseBytes, err := ioutil.ReadAll(w.Body)
responseBytes, err := io.ReadAll(w.Body)
if err != nil {
panic(err)
}
Expand All @@ -68,7 +68,7 @@ func TestAuthenticatedRequest(t *testing.T) {
expectedAlgorithm := jwt.SigningMethodHS256
w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm)
So(w.Code, ShouldEqual, http.StatusOK)
responseBytes, err := ioutil.ReadAll(w.Body)
responseBytes, err := io.ReadAll(w.Body)
if err != nil {
panic(err)
}
Expand All @@ -80,7 +80,7 @@ func TestAuthenticatedRequest(t *testing.T) {
expectedAlgorithm := jwt.SigningMethodRS256
w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm)
So(w.Code, ShouldEqual, http.StatusUnauthorized)
responseBytes, err := ioutil.ReadAll(w.Body)
responseBytes, err := io.ReadAll(w.Body)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func JWT(expectedSignatureAlgorithm jwt.SigningMethod) *JWTMiddleware {
var err error
privateKey, err = readPrivateKey()
if err != nil {
panic(err)
return nil, err
}
}
return privateKey, nil
Expand All @@ -193,9 +193,25 @@ func indexHandler(w http.ResponseWriter, r *http.Request) {
// in the token as json -> {"text":"bar"}
func protectedHandler(w http.ResponseWriter, r *http.Request) {
// retrieve the token from the context
u := r.Context().Value(userPropertyName)
user := u.(*jwt.Token)
respondJSON(user.Claims.(jwt.MapClaims)["foo"].(string), w)
u := r.Context().Value(contextKey(userPropertyName))
if u == nil {
http.Error(w, "Unauthorized: no token present", http.StatusUnauthorized)
return
}

user, ok := u.(*jwt.Token)
if !ok {
http.Error(w, "Unauthorized: invalid token type", http.StatusUnauthorized)
return
}

claims, ok := user.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, "Unauthorized: invalid claims type", http.StatusUnauthorized)
return
}

respondJSON(claims["foo"].(string), w)
}

// Response quick n' dirty Response struct to be encoded as json
Expand All @@ -215,3 +231,104 @@ func respondJSON(text string, w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jsonResponse)
}

func TestJWTMiddleware_Handler(t *testing.T) {
// Define a mock handler to be wrapped
mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

type fields struct {
Options Options
}
type args struct {
h http.Handler
}
tests := []struct {
name string
fields fields
args args
setupJWTCheck func(w *httptest.ResponseRecorder, r *http.Request) error
wantStatus int
}{
{
name: "Valid JWT",
fields: fields{
Options: Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Return a valid key for testing
return []byte("valid-signing-key"), nil
},
SigningMethod: jwt.SigningMethodHS256,
UserProperty: "user",
Extractor: FromAuthHeader,
},
},
args: args{
h: mockHandler,
},
setupJWTCheck: func(w *httptest.ResponseRecorder, r *http.Request) error {
// Create a valid JWT token
token := jwt.New(jwt.SigningMethodHS256)
tokenString, err := token.SignedString([]byte("valid-signing-key"))
if err != nil {
return err
}
// Add the JWT token to the request's Authorization header
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tokenString))
return nil
},
wantStatus: http.StatusOK,
},
{
name: "Invalid JWT",
fields: fields{
Options: Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Return a valid key for testing
return []byte("valid-signing-key"), nil
},
Extractor: FromAuthHeader,
SigningMethod: jwt.SigningMethodHS256,
UserProperty: "user",
ErrorHandler: OnError,
},
},
args: args{
h: mockHandler,
},
setupJWTCheck: func(w *httptest.ResponseRecorder, r *http.Request) error {
// Add an invalid JWT token to the request's Authorization header
r.Header.Set("Authorization", "Bearer invalidtoken")
return nil
},
wantStatus: http.StatusUnauthorized,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a new JWTMiddleware instance
m := &JWTMiddleware{
Options: tt.fields.Options,
}

// Create a new HTTP test recorder and request
recorder := httptest.NewRecorder()
request, _ := http.NewRequest("GET", "/", nil)

// Apply the JWT setup (valid/invalid token)
if err := tt.setupJWTCheck(recorder, request); err != nil {
t.Fatalf("failed to set up JWT check: %v", err)
}

// Invoke the middleware handler
m.Handler(tt.args.h).ServeHTTP(recorder, request)

// Check the status code
if recorder.Code != tt.wantStatus {
t.Errorf("Handler() status = %v, want %v", recorder.Code, tt.wantStatus)
}
})
}
}
6 changes: 3 additions & 3 deletions pkg/config/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ package config
import (
"crypto/subtle"
"fmt"
"os"
"path/filepath"

"github.com/openflagr/flagr/pkg/config/jwtmiddleware"

"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -118,7 +118,7 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler {
}

// Otherwise, serve index.html for Vue.js routing
http.ServeFile(w, r, "./browser/flagr-ui/dist/index.html")
http.ServeFile(w, r, "../../browser/flagr-ui/dist/index.html")
})

n.Use(setupRecoveryMiddleware())
Expand Down
1 change: 1 addition & 0 deletions pkg/config/middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// nolint: errcheck
package config

import (
Expand Down
1 change: 1 addition & 0 deletions pkg/entity/db.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// nolint: errcheck
package entity

import (
Expand Down
1 change: 1 addition & 0 deletions pkg/entity/fixture.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// nolint: errcheck
package entity

import (
Expand Down
1 change: 1 addition & 0 deletions pkg/entity/flag_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package entity

import (
"fmt"

jsoniter "github.com/json-iterator/go"
"github.com/openflagr/flagr/pkg/config"
"github.com/openflagr/flagr/pkg/util"
Expand Down
1 change: 1 addition & 0 deletions pkg/entity/variant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package entity
import (
"database/sql/driver"
"fmt"

jsoniter "github.com/json-iterator/go"

"github.com/openflagr/flagr/pkg/util"
Expand Down
1 change: 0 additions & 1 deletion pkg/handler/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ func (c *crud) PutFlag(params flag.PutFlagParams) middleware.Responder {
}
resp.SetPayload(payload)

fmt.Printf("Subject: %s\n", getSubjectFromRequest(params.HTTPRequest))
entity.SaveFlagSnapshot(getDB(), util.SafeUint(params.FlagID), getSubjectFromRequest(params.HTTPRequest))
return resp
}
Expand Down
6 changes: 1 addition & 5 deletions pkg/handler/subject.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handler

import (
"fmt"
"net/http"

"github.com/openflagr/flagr/pkg/config"
Expand All @@ -21,10 +20,7 @@ func getSubjectFromRequest(r *http.Request) string {
return ""
}

fmt.Printf("Token: %+v\n", token)
claims, ok := token.Claims.(jwt.MapClaims)

if ok && token.Valid {
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
return util.SafeString(claims[config.Config.JWTAuthUserClaim])
}

Expand Down

0 comments on commit 380b101

Please sign in to comment.