From ddb3eab3cd3e30fa5cb24af47bbc538b2d8ff067 Mon Sep 17 00:00:00 2001 From: Mario Hros Date: Wed, 6 May 2020 20:26:50 +0200 Subject: [PATCH 1/9] formatting, linting, documenting - without functional changes cherry-pick 7e59a13 --- cmd/main.go | 8 ++- internal/authentication/auth.go | 62 +++++++++++------- internal/authentication/auth_test.go | 4 +- internal/authorization/authorizer.go | 1 + internal/authorization/rbac/rbac.go | 58 +++++++++++------ internal/authorization/rbac/rbac_test.go | 8 +-- internal/authorization/user.go | 3 + internal/authorization/util.go | 1 + internal/configuration/config.go | 17 +++-- internal/handlers/server.go | 80 ++++++++++++++++-------- internal/util/cookiedomain.go | 8 ++- 11 files changed, 167 insertions(+), 83 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index 2f71028..5228ec3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -29,7 +29,9 @@ func main() { config.Validate() // Query the OIDC provider - config.SetOidcProvider() + if err := config.LoadOIDCProviderConfiguration(); err != nil { + log.Fatalln(err.Error()) + } authenticator := authentication.NewAuthenticator(config) // Get clientset for Authorizers @@ -77,7 +79,7 @@ func main() { http.HandleFunc("/", server.RootHandler) // Start - log.Debugf("Starting with options: %s", config) - log.Info("Listening on :4181") + log.Debugf("starting with options: %s", config) + log.Info("listening on :4181") log.Info(http.ListenAndServe(":4181", nil)) } diff --git a/internal/authentication/auth.go b/internal/authentication/auth.go index f024dd6..1dee77e 100644 --- a/internal/authentication/auth.go +++ b/internal/authentication/auth.go @@ -25,7 +25,8 @@ func NewAuthenticator(config *configuration.Config) *Authenticator { // Request Validation -// Cookie = hash(secret, cookie domain, email, expires)|expires|email|groups +// ValidateCookie validates the ID cookie in the request +// IDCookie = hash(secret, cookie domain, email, expires)|expires|email|group func (a *Authenticator) ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { parts := strings.Split(c.Value, "|") @@ -63,7 +64,8 @@ func (a *Authenticator) ValidateCookie(r *http.Request, c *http.Cookie) (string, return parts[2], nil } -// Validate email +// ValidateEmail validates that the provided email ends with one of the configured Domains or is part of the configured Whitelist. +// Also returns true if there is no Whitelist and no Domains configured. func (a *Authenticator) ValidateEmail(email string) bool { if len(a.config.Whitelist) > 0 || len(a.config.Domains) > 0 { for _, whitelist := range a.config.Whitelist { @@ -83,17 +85,20 @@ func (a *Authenticator) ValidateEmail(email string) bool { return true } -// Get oauth redirect uri -func (a *Authenticator) RedirectUri(r *http.Request) string { +// ComposeRedirectURI generates oauth redirect uri to return to from the OAuth2 provider +func (a *Authenticator) ComposeRedirectURI(r *http.Request) string { if use, _ := a.useAuthDomain(r); use { - proto := r.Header.Get("X-Forwarded-Proto") - return fmt.Sprintf("%s://%s%s", proto, a.config.AuthHost, a.config.Path) + scheme := r.Header.Get("X-Forwarded-Proto") + return fmt.Sprintf("%s://%s%s", scheme, a.config.AuthHost, a.config.Path) + } - return fmt.Sprintf("%s%s", redirectBase(r), a.config.Path) + return fmt.Sprintf("%s%s", getRequestSchemeHost(r), a.config.Path) } -// Should we use auth host + what it is +// useAuthDomain decides whether the host of the forwarded request +// matches the configured AuthHost and whether we can configure cookies for the AuthHost +// If it does, the function returns true and the top-level domain from the config we can use func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) { if a.config.AuthHost == "" { return false, "" @@ -111,7 +116,7 @@ func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) { // Cookie methods -// Create an auth cookie +// MakeIDCookie creates an auth cookie func (a *Authenticator) MakeIDCookie(r *http.Request, email string) *http.Cookie { expires := a.cookieExpiry() mac := a.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix())) @@ -128,7 +133,7 @@ func (a *Authenticator) MakeIDCookie(r *http.Request, email string) *http.Cookie } } -// Create a name cookie +// MakeNameCookie creates a name cookie func (a *Authenticator) MakeNameCookie(r *http.Request, name string) *http.Cookie { expires := a.cookieExpiry() @@ -143,7 +148,7 @@ func (a *Authenticator) MakeNameCookie(r *http.Request, name string) *http.Cooki } } -// Make a CSRF cookie (used during login only) +// MakeCSRFCookie creates a CSRF cookie (used during login only) func (a *Authenticator) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { return &http.Cookie{ Name: a.config.CSRFCookieName, @@ -156,7 +161,7 @@ func (a *Authenticator) MakeCSRFCookie(r *http.Request, nonce string) *http.Cook } } -// Create a cookie to clear csrf cookie +// ClearCSRFCookie clears the csrf cookie func (a *Authenticator) ClearCSRFCookie(r *http.Request) *http.Cookie { return &http.Cookie{ Name: a.config.CSRFCookieName, @@ -169,7 +174,7 @@ func (a *Authenticator) ClearCSRFCookie(r *http.Request) *http.Cookie { } } -// Validate the csrf cookie against state +// ValidateCSRFCookie validates the csrf cookie against state func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) { state := r.URL.Query().Get("state") @@ -190,15 +195,16 @@ func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) { return true, state[33:], nil } -func Nonce() (error, string) { +// GenerateNonce generates a random nonce string +func GenerateNonce() (string, error) { // Make nonce nonce := make([]byte, 16) _, err := rand.Read(nonce) if err != nil { - return err, "" + return "", err } - return nil, fmt.Sprintf("%x", nonce) + return fmt.Sprintf("%x", nonce), nil } // Cookie domain @@ -224,7 +230,10 @@ func (a *Authenticator) csrfCookieDomain(r *http.Request) string { return p[0] } -// Return matching cookie domain if exists +// matchCookieDomains checks if the provided domain maches any domain configured in the CookieDomains list +// and returns the domain from the list it matched with. +// The match is either the direct equality of domain names or the input subdomain (e.g. "a.test.com") belongs under a configured top domain ("test.com"). +// If the domain does not match CookieDomains, false is returned with the input domain as the second return value. func (a *Authenticator) matchCookieDomains(domain string) (bool, string) { // Remove port p := strings.Split(domain, ":") @@ -240,7 +249,7 @@ func (a *Authenticator) matchCookieDomains(domain string) (bool, string) { return false, p[0] } -// Create cookie hmac +// cookieSignature creates a cookie hmac func (a *Authenticator) cookieSignature(r *http.Request, email, expires string) string { hash := hmac.New(sha256.New, a.config.Secret) hash.Write([]byte(a.GetCookieDomain(r))) @@ -256,21 +265,26 @@ func (a *Authenticator) cookieExpiry() time.Time { // Utility methods -// Get the redirect base -func redirectBase(r *http.Request) string { +// getRequestSchemeHost returns scheme://host part of the request +// Example output: "https://domain.com" +func getRequestSchemeHost(r *http.Request) string { proto := r.Header.Get("X-Forwarded-Proto") host := r.Header.Get("X-Forwarded-Host") return fmt.Sprintf("%s://%s", proto, host) } -func GetUriPath(r *http.Request) string { +// GetRequestURI returns the full request URI with query parameters. +// The path includes the prefix (if stripPrefix middleware was used). +// Example output: "/prefix/path?query=1" +func GetRequestURI(r *http.Request) string { prefix := r.Header.Get("X-Forwarded-Prefix") uri := r.Header.Get("X-Forwarded-Uri") return fmt.Sprintf("%s/%s", strings.TrimRight(prefix, "/"), strings.TrimLeft(uri, "/")) } -// // Return url -func ReturnUrl(r *http.Request) string { - return fmt.Sprintf("%s%s", redirectBase(r), GetUriPath(r)) +// GetRequestURL returns full requst URL scheme://host/uri with query params +// Example output: "https://domain.com/prefix/path?query=1" +func GetRequestURL(r *http.Request) string { + return fmt.Sprintf("%s%s", getRequestSchemeHost(r), GetRequestURI(r)) } diff --git a/internal/authentication/auth_test.go b/internal/authentication/auth_test.go index a21a5d4..7c4d21c 100644 --- a/internal/authentication/auth_test.go +++ b/internal/authentication/auth_test.go @@ -220,11 +220,11 @@ func TestAuthValidateCSRFCookie(t *testing.T) { func TestAuthNonce(t *testing.T) { assert := assert.New(t) - err, nonce1 := Nonce() + nonce1, err := GenerateNonce() assert.Nil(err, "error generating nonce") assert.Len(nonce1, 32, "length should be 32 chars") - err, nonce2 := Nonce() + nonce2, err := GenerateNonce() assert.Nil(err, "error generating nonce") assert.Len(nonce2, 32, "length should be 32 chars") diff --git a/internal/authorization/authorizer.go b/internal/authorization/authorizer.go index 47bf887..a3de839 100644 --- a/internal/authorization/authorizer.go +++ b/internal/authorization/authorizer.go @@ -1,5 +1,6 @@ package authorization +// Authorizer is the interface for implementing user authorization (check to see if the user can perform the action) type Authorizer interface { Authorize(user User, requestVerb, requestResource string) (bool, error) } diff --git a/internal/authorization/rbac/rbac.go b/internal/authorization/rbac/rbac.go index adbc200..ddd9e28 100644 --- a/internal/authorization/rbac/rbac.go +++ b/internal/authorization/rbac/rbac.go @@ -19,7 +19,8 @@ const ( cacheSyncDuration = time.Minute * 10 ) -type RBACAuthorizer struct { +// Authorizer implements the authorizer by watching and using ClusterRole and ClusterRoleBinding Kubernetes (RBAC) objects +type Authorizer struct { clientset kubernetes.Interface clusterRoleLister rbaclisterv1.ClusterRoleLister clusterRoleBindingLister rbaclisterv1.ClusterRoleBindingLister @@ -29,8 +30,9 @@ type RBACAuthorizer struct { selector labels.Selector } -func NewRBACAuthorizer(clientset kubernetes.Interface) *RBACAuthorizer { - authz := &RBACAuthorizer{ +// NewAuthorizer creates a new RBAC authorizer +func NewAuthorizer(clientset kubernetes.Interface) *Authorizer { + authz := &Authorizer{ clientset: clientset, syncDuration: cacheSyncDuration, selector: labels.NewSelector(), @@ -41,7 +43,9 @@ func NewRBACAuthorizer(clientset kubernetes.Interface) *RBACAuthorizer { } // Private -func (ra *RBACAuthorizer) getRoleByName(name string) *rbacv1.ClusterRole { + +// getRoleByName finds the ClusterRole by its name or returns nil +func (ra *Authorizer) getRoleByName(name string) *rbacv1.ClusterRole { clusterRole, err := ra.clusterRoleLister.Get(name) if err != nil { if errors.IsNotFound(err) { @@ -59,25 +63,32 @@ func (ra *RBACAuthorizer) getRoleByName(name string) *rbacv1.ClusterRole { return clusterRole } -func (ra *RBACAuthorizer) getRoleFromGroups(target, role string, groups []string) *rbacv1.ClusterRole { - for _, group := range groups { - if group == target { - return ra.getRoleByName(role) +// getRoleFromGroups returns role specified in roleNameRef only if subjectGroupName is in the userGroups list +func (ra *Authorizer) getRoleFromGroups(roleNameRef, subjectGroupName string, userGroups []string) *rbacv1.ClusterRole { + // for every user group... + for _, group := range userGroups { + // if the group matches the group name in the subject, return the role + if group == subjectGroupName { + return ra.getRoleByName(roleNameRef) } } + + // no user group match this subjectGroupName return nil } -func (ra *RBACAuthorizer) getRoleForSubject(user authorization.User, subject rbacv1.Subject, role string) *rbacv1.ClusterRole { +// getRoleForSubject gets the role bound to the subject depending on the subject kind (user or group). +// Returns nil if there is no rule matching or an unknown subject Kind is provided +func (ra *Authorizer) getRoleForSubject(user authorization.User, subject rbacv1.Subject, roleNameRef string) *rbacv1.ClusterRole { if subject.Kind == "User" && subject.Name == user.GetName() { - return ra.getRoleByName(role) + return ra.getRoleByName(roleNameRef) } else if subject.Kind == "Group" { - return ra.getRoleFromGroups(subject.Name, role, user.GetGroups()) + return ra.getRoleFromGroups(roleNameRef, subject.Name, user.GetGroups()) } return nil } -func (ra *RBACAuthorizer) prepareCache() { +func (ra *Authorizer) prepareCache() { ra.sharedInformerFactory = informers.NewSharedInformerFactory(ra.clientset, ra.syncDuration) ra.clusterRoleLister = ra.sharedInformerFactory.Rbac().V1().ClusterRoles().Lister() ra.clusterRoleBindingLister = ra.sharedInformerFactory.Rbac().V1().ClusterRoleBindings().Lister() @@ -86,7 +97,9 @@ func (ra *RBACAuthorizer) prepareCache() { } // Public -func (ra *RBACAuthorizer) GetRoles(user authorization.User) (*rbacv1.ClusterRoleList, error) { + +// GetRolesBoundToUser returns list of roles bound to the specified user or groups the user is part of +func (ra *Authorizer) GetRolesBoundToUser(user authorization.User) (*rbacv1.ClusterRoleList, error) { clusterRoles := rbacv1.ClusterRoleList{} clusterRoleBindings, err := ra.clusterRoleBindingLister.List(ra.selector) if err != nil { @@ -105,29 +118,37 @@ func (ra *RBACAuthorizer) GetRoles(user authorization.User) (*rbacv1.ClusterRole } // Interface methods -func (ra *RBACAuthorizer) Authorize(user authorization.User, requestVerb, requestResource string) (bool, error) { - roles, err := ra.GetRoles(user) + +// Authorize performs the authorization logic +func (ra *Authorizer) Authorize(user authorization.User, requestVerb, requestResource string) (bool, error) { + roles, err := ra.GetRolesBoundToUser(user) if err != nil { return false, err } + // deny if no roles defined if len(roles.Items) < 1 { return false, nil } + // check all rules in the list of roles to see if any matches for _, role := range roles.Items { for _, rule := range role.Rules { - if VerbMatches(&rule, requestVerb) && NonResourceURLMatches(&rule, requestResource) { + if verbMatches(&rule, requestVerb) && nonResourceURLMatches(&rule, requestResource) { return true, nil } } } + // no rules match the request -> deny return false, nil } // Utility -func VerbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool { + +// verbMatches returns true if the requested verb matches a verb specifid in the rule +// Also matches if the rule mentiones special "all verbs" rule * +func verbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool { for _, ruleVerb := range rule.Verbs { if ruleVerb == rbacv1.VerbAll { return true @@ -140,7 +161,8 @@ func VerbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool { return false } -func NonResourceURLMatches(rule *rbacv1.PolicyRule, requestedURL string) bool { +// nonResourceURLMatches returns true if the requested URL matches a policy the rule +func nonResourceURLMatches(rule *rbacv1.PolicyRule, requestedURL string) bool { for _, ruleURL := range rule.NonResourceURLs { if ruleURL == rbacv1.NonResourceAll { return true diff --git a/internal/authorization/rbac/rbac_test.go b/internal/authorization/rbac/rbac_test.go index 4aa5297..ab5f6a9 100644 --- a/internal/authorization/rbac/rbac_test.go +++ b/internal/authorization/rbac/rbac_test.go @@ -24,8 +24,8 @@ type testCase struct { should bool } -func getRBACAuthorizer(objs ...runtime.Object) *RBACAuthorizer { - return NewRBACAuthorizer(fake.NewSimpleClientset(objs...)) +func getRBACAuthorizer(objs ...runtime.Object) *Authorizer { + return NewAuthorizer(fake.NewSimpleClientset(objs...)) } func makeRole(name string, verbs, urls []string) rbacv1.ClusterRole { @@ -98,7 +98,7 @@ func TestRBACAuthorizer_GetRoles(t *testing.T) { a := getRBACAuthorizer(roles, bindings) u1 := authorization.User{Name: "u1"} - r, err := a.GetRoles(u1) + r, err := a.GetRolesBoundToUser(u1) assert.NilError(t, err) assert.Equal(t, len(r.Items), 2) @@ -107,7 +107,7 @@ func TestRBACAuthorizer_GetRoles(t *testing.T) { u2 := authorization.User{Name: "u2", Groups: []string{"g1", "g2"}} - r, err = a.GetRoles(u2) + r, err = a.GetRolesBoundToUser(u2) assert.NilError(t, err) assert.Equal(t, len(r.Items), 1) assert.Equal(t, r.Items[0].Name, "r3") diff --git a/internal/authorization/user.go b/internal/authorization/user.go index 3f9d3be..9f0b689 100644 --- a/internal/authorization/user.go +++ b/internal/authorization/user.go @@ -1,14 +1,17 @@ package authorization +// User represents an autorized user type User struct { Name string Groups []string } +// GetName returns the user name func (k *User) GetName() string { return k.Name } +// GetGroups return list of groups the user belongs to func (k *User) GetGroups() []string { return k.Groups } diff --git a/internal/authorization/util.go b/internal/authorization/util.go index f415b2e..80e6e29 100644 --- a/internal/authorization/util.go +++ b/internal/authorization/util.go @@ -4,6 +4,7 @@ import ( "strings" ) +// PathMatches returns true if the URL matches the pattern containing an optional wildcard '*' character func PathMatches(url, pattern string) bool { return pattern == url || (strings.HasSuffix(pattern, "*") && strings.HasPrefix(url, strings.TrimRight(pattern, "*"))) diff --git a/internal/configuration/config.go b/internal/configuration/config.go index f15d2a8..4766d6c 100644 --- a/internal/configuration/config.go +++ b/internal/configuration/config.go @@ -29,12 +29,13 @@ var ( log logrus.FieldLogger ) +// Config holds app configuration type Config struct { LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"` LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"` - ProviderUri string `long:"provider-uri" env:"PROVIDER_URI" description:"OIDC Provider URI"` - ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"` + ProviderURI string `long:"provider-uri" env:"PROVIDER_URI" description:"OIDC Provider URI"` + ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"` ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"` Scope string `long:"scope" env:"SCOPE" description:"Define scope"` AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"` @@ -209,6 +210,7 @@ func convertLegacyToIni(name string) (io.Reader, error) { return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil } +// Validate validates the provided config func (c *Config) Validate() { // Check for show stopper errors if len(c.SecretString) == 0 { @@ -217,7 +219,7 @@ func (c *Config) Validate() { log.Infoln("for better security, \"secret\" should ideally be 32 bytes or longer") } - if c.ProviderUri == "" || c.ClientId == "" || c.ClientSecret == "" { + if c.ProviderURI == "" || c.ClientID == "" || c.ClientSecret == "" { log.Fatal("provider-uri, client-id, client-secret must be set") } @@ -250,13 +252,16 @@ func (c *Config) Validate() { } } -func (c *Config) SetOidcProvider() { +// LoadOIDCProviderConfiguration loads the configuration of OpenID Connect provider +func (c *Config) LoadOIDCProviderConfiguration() error { // Fetch OIDC Provider configuration - provider, err := oidc.NewProvider(c.OIDCContext, c.ProviderUri) + provider, err := oidc.NewProvider(c.OIDCContext, c.ProviderURI) if err != nil { - log.Fatalf("failed to get provider configuration for %s: %v (hint: make sure %s is accessible from the cluster)", c.ProviderUri, err, c.ProviderUri) + return fmt.Errorf("failed to get provider configuration for %s: %v (hint: make sure %s is accessible from the cluster)", + c.ProviderURI, err, c.ProviderURI) } c.OIDCProvider = provider + return nil } func (c Config) String() string { diff --git a/internal/handlers/server.go b/internal/handlers/server.go index f52720c..6555d0c 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -26,6 +26,7 @@ const ( impersonateGroupHeader = "Impersonate-Group" ) +// Server implements the HTTP server handling forwardauth type Server struct { router *rules.Router userinfo v1alpha1.UserInfoInterface @@ -35,6 +36,7 @@ type Server struct { authenticator *authentication.Authenticator } +// NewServer creates a new forwardauth server func NewServer(userinfo v1alpha1.UserInfoInterface, clientset kubernetes.Interface, config *configuration.Config) *Server { s := &Server{ log: internallog.NewDefaultLogger(config.LogLevel, config.LogFormat), @@ -46,7 +48,7 @@ func NewServer(userinfo v1alpha1.UserInfoInterface, clientset kubernetes.Interfa s.buildRoutes() s.userinfo = userinfo if config.EnableRBAC { - s.authorizer = rbac.NewRBACAuthorizer(clientset) + s.authorizer = rbac.NewAuthorizer(clientset) } return s } @@ -82,6 +84,7 @@ func (s *Server) buildRoutes() { } } +// RootHandler it the main handler (for / path) func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { logger := s.log.WithFields(logrus.Fields{ "X-Forwarded-Method": r.Header.Get("X-Forwarded-Method"), @@ -94,7 +97,7 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { // Modify request r.Method = r.Header.Get("X-Forwarded-Method") r.Host = r.Header.Get("X-Forwarded-Host") - r.URL, _ = neturl.Parse(authentication.GetUriPath(r)) + r.URL, _ = neturl.Parse(authentication.GetRequestURI(r)) if s.config.AuthHost == "" || len(s.config.CookieDomains) > 0 || r.Host == s.config.AuthHost { s.router.ServeHTTP(w, r) @@ -103,12 +106,12 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { url := r.URL url.Scheme = r.Header.Get("X-Forwarded-Proto") url.Host = s.config.AuthHost - logger.Debugf("Redirect to %v", url.String()) + logger.Debugf("redirect to %v", url.String()) http.Redirect(w, r, url.String(), 307) } } -// Handler that allows requests +// AllowHandler handles the request as implicite "allow", returining HTTP 200 response to the Traefik func (s *Server) AllowHandler(rule string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { s.logger(r, rule, "Allow request") @@ -116,7 +119,9 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc { } } -// Authenticate requests +// AuthHandler handles the request as requiring authentication. +// It validates the existing session, starting a new auth flow if the session is not valid. +// Finally it also performs authorization (if enabled) to ensure the logged-in subject is authorized to perform the request. func (s *Server) AuthHandler(rule string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Logging setup @@ -162,7 +167,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { } if s.config.EnableRBAC && !s.authzIsBypassed(r) { - kubeUserInfo := s.getModifiedUserInfo(email, groups) + kubeUserInfo := s.makeKubeUserInfo(email, groups) logger.Debugf("authorizing user: %s, groups: %s", kubeUserInfo.Name, kubeUserInfo.Groups) authorized, err := s.authorizer.Authorize(kubeUserInfo, r.Method, r.URL.Path) @@ -223,7 +228,8 @@ func cleanupConnectionHeader(original string) string { return strings.TrimSpace(strings.Join(passThrough, ",")) } -// Handle auth callback +// AuthCallbackHandler handles the request as a callback from authentication provider. +// It validates CSRF, exchanges code-token for id-token and extracts groups from the id-token. func (s *Server) AuthCallbackHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Logging setup @@ -232,7 +238,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Check for CSRF cookie c, err := r.Cookie(s.config.CSRFCookieName) if err != nil { - logger.Warnf("Missing CSRF cookie: %v", err) + logger.Warnf("missing CSRF cookie: %v", err) http.Error(w, "Not authorized", 401) return } @@ -240,7 +246,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Validate state valid, redirect, err := authentication.ValidateCSRFCookie(r, c) if !valid { - logger.Warnf("Error validating CSRF cookie: %v", err) + logger.Warnf("error validating CSRF cookie: %v", err) http.Error(w, "Not authorized", 401) return } @@ -259,9 +265,9 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } oauth2Config := oauth2.Config{ - ClientID: s.config.ClientId, + ClientID: s.config.ClientID, ClientSecret: s.config.ClientSecret, - RedirectURL: s.authenticator.RedirectUri(r), + RedirectURL: s.authenticator.ComposeRedirectURI(r), Endpoint: provider.Endpoint(), Scopes: scope, } @@ -283,7 +289,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Parse and verify ID Token payload. - verifier := provider.Verifier(&oidc.Config{ClientID: s.config.ClientId}) + verifier := provider.Verifier(&oidc.Config{ClientID: s.config.ClientID}) idToken, err := verifier.Verify(s.config.OIDCContext, rawIDToken) if err != nil { logger.Warnf("failed to verify token: %v", err) @@ -306,9 +312,9 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { http.SetCookie(w, s.authenticator.MakeIDCookie(r, email.(string))) logger.WithFields(logrus.Fields{ "user": claims["email"].(string), - }).Infof("Generated auth cookie") + }).Infof("generated auth cookie") } else { - logger.Errorf("failed to get email claims session") + logger.Errorf("no email claim present in the ID token") } // If name in null, empty or whitespace, use email address for name @@ -320,7 +326,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { http.SetCookie(w, s.authenticator.MakeNameCookie(r, name.(string))) logger.WithFields(logrus.Fields{ "name": name.(string), - }).Infof("Generated name cookie") + }).Infof("generated name cookie") // Mapping groups groups := []string{} @@ -331,7 +337,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { groups[i] = v.(string) } } else { - logger.Errorf("failed to get groups claims session. GroupsAttributeName: %s", s.config.GroupsAttributeName) + logger.Errorf("failed to get groups claim from the ID token (GroupsAttributeName: %s)", s.config.GroupsAttributeName) } logger.Printf("creating claims session with groups: %v", groups) @@ -349,7 +355,13 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } } +// notAuthenticated is used to signal the request does not include a valid authentication data. +// If the request came from a browser (having "text/html" in the Accept header), authentication +// redirect is made to start a new auth flow. Otherwise the "Authenticatio expired" message +// is passed as one of the known content-types or as a plain text. func (s *Server) notAuthenticated(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) { + bestFormat := "" + // Redirect if request accepts HTML. Fail if request is AJAX, image, etc acceptHeader := r.Header.Get("Accept") acceptParts := strings.Split(acceptHeader, ",") @@ -359,24 +371,39 @@ func (s *Server) notAuthenticated(logger *logrus.Entry, w http.ResponseWriter, r if format == "text/html" || (i == 0 && format == "*/*") { s.authRedirect(logger, w, r) return + } else if strings.HasPrefix(format, "application/json") { + bestFormat = "json" + } else if strings.HasPrefix(format, "application/xml") { + bestFormat = "xml" } } logger.Warnf("Non-HTML request: %v", acceptHeader) - http.Error(w, "Authentication expired. Reload page to re-authenticate.", 401) + + errStr := "Authentication expired. Reload page to re-authenticate." + if bestFormat == "json" { + w.Header().Set("Content-Type", "application/json") + http.Error(w, `{"error": "`+errStr+`"}`, 401) + } else if bestFormat == "xml" { + w.Header().Set("Content-Type", "application/xml") + http.Error(w, ``+errStr+``, 401) + } else { + http.Error(w, errStr, 401) + } } +// authRedirect generates CSRF cookie and redirests to authentication provider to start the authentication flow. func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) { // Error indicates no cookie, generate nonce - err, nonce := authentication.Nonce() + nonce, err := authentication.GenerateNonce() if err != nil { - logger.Errorf("Error generating nonce, %v", err) + logger.Errorf("error generating nonce, %v", err) http.Error(w, "Service unavailable", 503) return } // Set the CSRF cookie http.SetCookie(w, s.authenticator.MakeCSRFCookie(r, nonce)) - logger.Debug("Set CSRF cookie and redirect to OIDC login") + logger.Debug("sending CSRF cookie and a redirect to OIDC login") // Mapping scope var scope []string @@ -392,20 +419,21 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht } oauth2Config := oauth2.Config{ - ClientID: s.config.ClientId, + ClientID: s.config.ClientID, ClientSecret: s.config.ClientSecret, - RedirectURL: s.authenticator.RedirectUri(r), + RedirectURL: s.authenticator.ComposeRedirectURI(r), Endpoint: s.config.OIDCProvider.Endpoint(), Scopes: scope, } - state := fmt.Sprintf("%s:%s", nonce, authentication.ReturnUrl(r)) + state := fmt.Sprintf("%s:%s", nonce, authentication.GetRequestURL(r)) http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusFound) return } +// logger provides a new logger enriched with request info func (s *Server) logger(r *http.Request, rule, msg string) *logrus.Entry { // Create logger logger := s.log.WithFields(logrus.Fields{ @@ -421,6 +449,7 @@ func (s *Server) logger(r *http.Request, rule, msg string) *logrus.Entry { return logger } +// getGroupsFromSession returns list of groups present in the session func (s *Server) getGroupsFromSession(r *http.Request) ([]string, error) { userInfo, err := s.userinfo.Get(r) if err != nil { @@ -429,6 +458,7 @@ func (s *Server) getGroupsFromSession(r *http.Request) ([]string, error) { return userInfo.Groups, nil } +// authzIsBypassed returns true if the request matches a bypass URI pattern func (s *Server) authzIsBypassed(r *http.Request) bool { for _, bypassURIPattern := range s.config.AuthZPassThrough { if authorization.PathMatches(r.URL.Path, bypassURIPattern) { @@ -439,8 +469,8 @@ func (s *Server) authzIsBypassed(r *http.Request) bool { return false } -// appends group prefix to groups -func (s *Server) getModifiedUserInfo(email string, groups []string) authorization.User { +// makeKubeUserInfo appends group prefix to all provided groups and adds "system:authenticated" group to the list +func (s *Server) makeKubeUserInfo(email string, groups []string) authorization.User { g := []string{"system:authenticated"} for _, group := range groups { g = append(g, fmt.Sprintf("%s%s", s.config.GroupClaimPrefix, group)) diff --git a/internal/util/cookiedomain.go b/internal/util/cookiedomain.go index 1c70944..0ff8e10 100644 --- a/internal/util/cookiedomain.go +++ b/internal/util/cookiedomain.go @@ -5,7 +5,7 @@ import ( "strings" ) -// CookieDomain +// CookieDomain represents a top-level cookie domain and helper functions on it type CookieDomain struct { Domain string `description:"TEST1"` DomainLen int `description:"TEST2"` @@ -13,6 +13,7 @@ type CookieDomain struct { SubDomainLen int `description:"TEST4"` } +// CookieDomains holds a list of cookie domains type CookieDomains []CookieDomain func NewCookieDomain(domain string) *CookieDomain { @@ -24,6 +25,7 @@ func NewCookieDomain(domain string) *CookieDomain { } } +// Match returns true if host matches the CookieDomain or is a subdomain of it func (c CookieDomain) Match(host string) bool { // Exact domain match? if host == c.Domain { @@ -38,17 +40,20 @@ func (c CookieDomain) Match(host string) bool { return false } +// UnmarshalFlag unmarshals the CookieDomain from the flag string func (c *CookieDomain) UnmarshalFlag(value string) error { *c = *NewCookieDomain(value) return nil } +// MarshalFlag marshals the CookieDomain into a flag string func (c *CookieDomain) MarshalFlag() (string, error) { return c.Domain, nil } // Legacy support for comma separated list of cookie domains +// UnmarshalFlag unmarshals the CookieDomains from the flag string func (c *CookieDomains) UnmarshalFlag(value string) error { if len(value) > 0 { for _, d := range strings.Split(value, ",") { @@ -59,6 +64,7 @@ func (c *CookieDomains) UnmarshalFlag(value string) error { return nil } +// MarshalFlag marshals the CookieDomain into a flag string func (c *CookieDomains) MarshalFlag() (string, error) { var domains []string for _, d := range *c { From 5fc2488c6d0f5ac022ddc2ea680fa324cdc43dea Mon Sep 17 00:00:00 2001 From: Mario Hros Date: Wed, 6 May 2020 21:33:45 +0200 Subject: [PATCH 2/9] optional case-insensitive username and groupname comparisons for RBAC + tests (cherry picked from commit 13427e825af8fb03f47fff896e455b397b823cc1) --- internal/authorization/rbac/rbac.go | 18 ++++++-- internal/authorization/rbac/rbac_test.go | 58 ++++++++++++++++++++++++ internal/configuration/config.go | 5 +- internal/handlers/server.go | 4 +- 4 files changed, 79 insertions(+), 6 deletions(-) diff --git a/internal/authorization/rbac/rbac.go b/internal/authorization/rbac/rbac.go index ddd9e28..68feb94 100644 --- a/internal/authorization/rbac/rbac.go +++ b/internal/authorization/rbac/rbac.go @@ -28,6 +28,8 @@ type Authorizer struct { syncDuration time.Duration informerStop chan struct{} selector labels.Selector + // If CaseInsensitiveSubjects is true, group and user names are compared case-insensitively (default false) + CaseInsensitiveSubjects bool } // NewAuthorizer creates a new RBAC authorizer @@ -68,7 +70,7 @@ func (ra *Authorizer) getRoleFromGroups(roleNameRef, subjectGroupName string, us // for every user group... for _, group := range userGroups { // if the group matches the group name in the subject, return the role - if group == subjectGroupName { + if compareSubjects(group, subjectGroupName, ra.CaseInsensitiveSubjects) { return ra.getRoleByName(roleNameRef) } } @@ -80,7 +82,7 @@ func (ra *Authorizer) getRoleFromGroups(roleNameRef, subjectGroupName string, us // getRoleForSubject gets the role bound to the subject depending on the subject kind (user or group). // Returns nil if there is no rule matching or an unknown subject Kind is provided func (ra *Authorizer) getRoleForSubject(user authorization.User, subject rbacv1.Subject, roleNameRef string) *rbacv1.ClusterRole { - if subject.Kind == "User" && subject.Name == user.GetName() { + if subject.Kind == "User" && compareSubjects(subject.Name, user.GetName(), ra.CaseInsensitiveSubjects) { return ra.getRoleByName(roleNameRef) } else if subject.Kind == "Group" { return ra.getRoleFromGroups(roleNameRef, subject.Name, user.GetGroups()) @@ -153,7 +155,7 @@ func verbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool { if ruleVerb == rbacv1.VerbAll { return true } - if strings.ToLower(ruleVerb) == strings.ToLower(requestedVerb) { + if strings.EqualFold(ruleVerb, requestedVerb) { return true } } @@ -173,3 +175,13 @@ func nonResourceURLMatches(rule *rbacv1.PolicyRule, requestedURL string) bool { } return false } + +// compareSubjects determines whether subjects names are equal (string equality is used). +// If caseInsensitive is true, the case of the characters is ignored, meaning "UserName" +// would be considered equal to "username". +func compareSubjects(s1, s2 string, caseInsensitive bool) bool { + if caseInsensitive == false { + return s1 == s2 + } + return strings.EqualFold(s1, s2) +} diff --git a/internal/authorization/rbac/rbac_test.go b/internal/authorization/rbac/rbac_test.go index ab5f6a9..16c63e7 100644 --- a/internal/authorization/rbac/rbac_test.go +++ b/internal/authorization/rbac/rbac_test.go @@ -156,3 +156,61 @@ func TestRBACAuthorizer_Authorize2(t *testing.T) { assert.NilError(t, err) assert.Equal(t, result, test.should) } + +func TestCaseInsensitiveSubjects(t *testing.T) { + type testCase struct { + authorizer authorization.Authorizer + user authorization.User + url string + should bool + } + + // declare roles and bindings all lower-case + role := makeRole("grafana-admin", []string{"*"}, []string{"/ops/portal/grafana", "/ops/portal/grafana/*"}) + rolebindings := makeClusterRoleBindingList( + makeBinding("User", "grafana-admin-boyle", "boyle@ldap.forumsys.com", "grafana-admin"), + makeBinding("Group", "grafana-admin-oidc-admins", "oidc:admins", "grafana-admin"), + ) + + // default authorizer + defaultAuthorizer := getRBACAuthorizer(&role, rolebindings) + + // case-insensitive authorizer + caseInsensitiveAuthorizer := getRBACAuthorizer(&role, rolebindings) + caseInsensitiveAuthorizer.CaseInsensitiveSubjects = true + + tests := []testCase{ + // users + { + authorizer: defaultAuthorizer, + user: authorization.User{Name: "Boyle@ldap.forumsys.com", Groups: []string{"oidc:chemists"}}, + url: "/ops/portal/grafana/rnJhmVJw.woff2", + should: false, // default case-sensitive user comparison shouldn't allow Boyle + }, + { + authorizer: caseInsensitiveAuthorizer, + user: authorization.User{Name: "Boyle@ldap.forumsys.com", Groups: []string{"oidc:chemists"}}, + url: "/ops/portal/grafana/rnJhmVJw.woff2", + should: true, // case-insensitive user comparison should allow Boyle + }, + // groups + { + authorizer: defaultAuthorizer, + user: authorization.User{Name: "agent47@ldap.forumsys.com", Groups: []string{"oidc:Admins"}}, + url: "/ops/portal/grafana/rnJhmVJw.woff2", + should: false, // default case-sensitive group comparison shouldn't allow Admins group + }, + { + authorizer: caseInsensitiveAuthorizer, + user: authorization.User{Name: "agent47@ldap.forumsys.com", Groups: []string{"oidc:Admins"}}, + url: "/ops/portal/grafana/rnJhmVJw.woff2", + should: true, // case-insensitive group comparison should allow Admins group + }, + } + + for _, test := range tests { + result, err := test.authorizer.Authorize(test.user, "GET", test.url) + assert.NilError(t, err) + assert.Equal(t, result, test.should) + } +} diff --git a/internal/configuration/config.go b/internal/configuration/config.go index 4766d6c..3ab0eba 100644 --- a/internal/configuration/config.go +++ b/internal/configuration/config.go @@ -61,8 +61,9 @@ type Config struct { GroupsAttributeName string `long:"groups-attribute-name" env:"GROUPS_ATTRIBUTE_NAME" default:"groups" description:"Map the correct attribute that contain the user groups"` // RBAC - EnableRBAC bool `long:"enable-rbac" env:"ENABLE_RBAC" description:"Indicates that RBAC support should be enabled"` - AuthZPassThrough CommaSeparatedList `long:"authz-pass-through" env:"AUTHZ_PASS_THROUGH" description:"One or more routes which bypass authorization checks"` + EnableRBAC bool `long:"enable-rbac" env:"ENABLE_RBAC" description:"Indicates that RBAC support should be enabled"` + AuthZPassThrough CommaSeparatedList `long:"authz-pass-through" env:"AUTHZ_PASS_THROUGH" description:"One or more routes which bypass authorization checks"` + CaseInsensitiveSubjects bool `long:"case-insensitive-subjects" env:"CASE_INSENSITIVE_SUBJECTS" description:"Make case-insensitive comparison of user and group names in the RBAC implementation"` // Storage EnableInClusterStorage bool `long:"enable-in-cluster-storage" env:"ENABLE_IN_CLUSTER_STORAGE" description:"When true, sessions are store in a kubernetes apiserver"` diff --git a/internal/handlers/server.go b/internal/handlers/server.go index 6555d0c..7611ecf 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -48,7 +48,9 @@ func NewServer(userinfo v1alpha1.UserInfoInterface, clientset kubernetes.Interfa s.buildRoutes() s.userinfo = userinfo if config.EnableRBAC { - s.authorizer = rbac.NewAuthorizer(clientset) + rbac := rbac.NewAuthorizer(clientset) + rbac.CaseInsensitiveSubjects = config.CaseInsensitiveSubjects + s.authorizer = rbac } return s } From b89a5cdc64e05794860dabe2acb1e9bdafe14691 Mon Sep 17 00:00:00 2001 From: Mario Hros Date: Wed, 6 May 2020 21:34:28 +0200 Subject: [PATCH 3/9] pass the logger optionally to the RBAC implementation and describe resync cherry-pick 0720eae --- internal/authorization/rbac/rbac.go | 33 ++++++++++++++++-------- internal/authorization/rbac/rbac_test.go | 2 +- internal/handlers/server.go | 2 +- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/internal/authorization/rbac/rbac.go b/internal/authorization/rbac/rbac.go index 68feb94..3331eac 100644 --- a/internal/authorization/rbac/rbac.go +++ b/internal/authorization/rbac/rbac.go @@ -2,6 +2,7 @@ package rbac import ( "log" + "os" "strings" "time" @@ -16,11 +17,21 @@ import ( ) const ( - cacheSyncDuration = time.Minute * 10 + // How often the informer should perform a resync (list all resources and rehydrate the informer’s store). + // This creates a higher guarantee that your informer’s store has a perfect picture of the resources it is watching. + // There are situations where events can be missed entirely and resyncing every so often solves this. + // Setting to 0 disables the resync and makes the informer subscribe to individual updates only. + defaultResyncDuration = time.Minute * 10 ) +// Logger is an interface for basic log output +type Logger interface { + Printf(format string, v ...interface{}) +} + // Authorizer implements the authorizer by watching and using ClusterRole and ClusterRoleBinding Kubernetes (RBAC) objects type Authorizer struct { + logger Logger clientset kubernetes.Interface clusterRoleLister rbaclisterv1.ClusterRoleLister clusterRoleBindingLister rbaclisterv1.ClusterRoleBindingLister @@ -32,11 +43,16 @@ type Authorizer struct { CaseInsensitiveSubjects bool } -// NewAuthorizer creates a new RBAC authorizer -func NewAuthorizer(clientset kubernetes.Interface) *Authorizer { +// NewAuthorizer creates a new RBAC authorizer. Logger can be nil to use standard error logger. +func NewAuthorizer(clientset kubernetes.Interface, logger Logger) *Authorizer { + if logger == nil { + logger = log.New(os.Stderr, "rbac", log.LstdFlags) + } + authz := &Authorizer{ + logger: logger, clientset: clientset, - syncDuration: cacheSyncDuration, + syncDuration: defaultResyncDuration, selector: labels.NewSelector(), informerStop: make(chan struct{}), } @@ -51,14 +67,9 @@ func (ra *Authorizer) getRoleByName(name string) *rbacv1.ClusterRole { clusterRole, err := ra.clusterRoleLister.Get(name) if err != nil { if errors.IsNotFound(err) { - // TFA's "internal" package doesn't make sense for expanding functionality. - // IMO, TFA should be rewritten completely using current golang design standards - // TODO(jr): Rewrite TFA as a lightweight forward proxy - // ^^ using stdlib log because I don't want to parse the configuration file again for - // two log messages... (jr) (or muck up my interfaces by passing in a log object..) - log.Printf("role binding %s is bound to non-existent role", name) + ra.logger.Printf("role binding is bound to non-existent role %s", name) } else { - log.Printf("error getting role bound to %s: %v", name, err) + ra.logger.Printf("error getting role %s from role binding: %v", name, err) } return nil } diff --git a/internal/authorization/rbac/rbac_test.go b/internal/authorization/rbac/rbac_test.go index 16c63e7..64bcaef 100644 --- a/internal/authorization/rbac/rbac_test.go +++ b/internal/authorization/rbac/rbac_test.go @@ -25,7 +25,7 @@ type testCase struct { } func getRBACAuthorizer(objs ...runtime.Object) *Authorizer { - return NewAuthorizer(fake.NewSimpleClientset(objs...)) + return NewAuthorizer(fake.NewSimpleClientset(objs...), nil) } func makeRole(name string, verbs, urls []string) rbacv1.ClusterRole { diff --git a/internal/handlers/server.go b/internal/handlers/server.go index 7611ecf..0586f01 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -48,7 +48,7 @@ func NewServer(userinfo v1alpha1.UserInfoInterface, clientset kubernetes.Interfa s.buildRoutes() s.userinfo = userinfo if config.EnableRBAC { - rbac := rbac.NewAuthorizer(clientset) + rbac := rbac.NewAuthorizer(clientset, s.log) rbac.CaseInsensitiveSubjects = config.CaseInsensitiveSubjects s.authorizer = rbac } From 39539aa1f682010131e5d12413cae2072a6084d8 Mon Sep 17 00:00:00 2001 From: Mario Hros Date: Fri, 8 May 2020 18:15:20 +0200 Subject: [PATCH 4/9] [breaking change] wildcard matching improved and additional matching types added Original wildcard prefix match now allows multiple * characters. For increased safety, a single '*' character now matches within one path segment only. To match any number of path segments, two consecutive charaters must be specified '**'. In addition to original "prefix match", full URL match is possible: 1. Wildcard Prefix Match `/admin/*` matches /admin/overview, /admin/users, *not* /admin/users/1 `/admin/**` matches what '/admin/*' matched + also '/admin/users/1' 2. Full URL Wildcard Match `*://a.com/admin` matches http://a.com/admin and https://a.com/admin `*://a.com/**` matches everything under http://a.com/ and https://a.com/ `https://b.com/admin` matches https://b.com/admin only 3. Full URL Regular Expression Match (prefixed by ~ character!) `~^https?://[cd].com/.*` matches everything under http://c.com/, http://d.com/ and their https versions Tests were extended for new functionality and updated for the fact that single '*' now matches within the one path component only. cherry-pick 36c3eee --- cmd/main.go | 13 +- go.mod | 1 + internal/authentication/auth.go | 84 +++++-------- internal/authentication/auth_test.go | 57 ++++----- internal/authorization/authorizer.go | 4 +- internal/authorization/rbac/rbac.go | 31 ++++- internal/authorization/rbac/rbac_test.go | 133 ++++++++++++++++++--- internal/authorization/urlpatterns.go | 126 +++++++++++++++++++ internal/authorization/urlpatterns_test.go | 82 +++++++++++++ internal/authorization/util.go | 11 -- internal/authorization/util_test.go | 29 ----- internal/configuration/config.go | 35 +++--- internal/configuration/config_test.go | 3 - internal/handlers/server.go | 64 +++++++--- internal/handlers/server_test.go | 115 ++++++++++-------- 15 files changed, 558 insertions(+), 230 deletions(-) create mode 100644 internal/authorization/urlpatterns.go create mode 100644 internal/authorization/urlpatterns_test.go delete mode 100644 internal/authorization/util.go delete mode 100644 internal/authorization/util_test.go diff --git a/cmd/main.go b/cmd/main.go index 5228ec3..ac656bc 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "time" + "fmt" "github.com/gorilla/sessions" logger "github.com/mesosphere/traefik-forward-auth/internal/log" @@ -20,7 +21,11 @@ import ( // Main func main() { // Parse options - config := configuration.NewGlobalConfig(os.Args[1:]) + config, err := configuration.NewGlobalConfig(os.Args[1:]) + if err != nil { + fmt.Printf("%+v\n", err) + os.Exit(1) + } // Setup logger log := logger.NewDefaultLogger(config.LogLevel, config.LogFormat) @@ -47,7 +52,9 @@ func main() { var userInfoStore v1alpha1.UserInfoInterface if !config.EnableInClusterStorage { // Prepare cookie session store (first key is for auth, the second one for encryption) - cookieStore := sessions.NewCookieStore(config.Secret, []byte(config.SessionKey)) + hashKey := []byte(config.SecretString) + blockKey := []byte(config.EncryptionKeyString) + cookieStore := sessions.NewCookieStore(hashKey, blockKey) cookieStore.Options.MaxAge = int(config.Lifetime / time.Second) cookieStore.Options.HttpOnly = true cookieStore.Options.Secure = !config.InsecureCookie @@ -61,7 +68,7 @@ func main() { userInfoStore = cluster.NewClusterStore( clientset, config.ClusterStoreNamespace, - string(config.Secret), + config.SecretString, config.Lifetime, time.Duration(config.ClusterStoreCacheTTL)*time.Second, authenticator) diff --git a/go.mod b/go.mod index 4031370..6af3f17 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/googleapis/gnostic v0.3.1 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/sessions v1.2.0 + github.com/gorilla/securecookie v1.1.1 github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff // indirect github.com/jonboulle/clockwork v0.1.0 // indirect github.com/json-iterator/go v1.1.8 // indirect diff --git a/internal/authentication/auth.go b/internal/authentication/auth.go index 1dee77e..a9de465 100644 --- a/internal/authentication/auth.go +++ b/internal/authentication/auth.go @@ -1,67 +1,51 @@ package authentication import ( - "crypto/hmac" "crypto/rand" - "crypto/sha256" - "encoding/base64" "errors" "fmt" "net/http" - "strconv" "strings" "time" + "github.com/gorilla/securecookie" + "github.com/mesosphere/traefik-forward-auth/internal/configuration" ) type Authenticator struct { - config *configuration.Config + config *configuration.Config + secureCookie *securecookie.SecureCookie } func NewAuthenticator(config *configuration.Config) *Authenticator { - return &Authenticator{config} + cookieMaxAge := int(config.Lifetime / time.Second) + hashKey := []byte(config.SecretString) + blockKey := []byte(config.EncryptionKeyString) + + return &Authenticator{ + config: config, + secureCookie: securecookie.New(hashKey, blockKey).MaxAge(cookieMaxAge), + } +} + +type ID struct { + Email string + Token string } // Request Validation // ValidateCookie validates the ID cookie in the request // IDCookie = hash(secret, cookie domain, email, expires)|expires|email|group -func (a *Authenticator) ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { - parts := strings.Split(c.Value, "|") - - if len(parts) != 3 { - return "", errors.New("invalid cookie format") - } - - mac, err := base64.URLEncoding.DecodeString(parts[0]) - if err != nil { - return "", errors.New("unable to decode cookie mac") - } - - expectedSignature := a.cookieSignature(r, parts[2], parts[1]) - expected, err := base64.URLEncoding.DecodeString(expectedSignature) - if err != nil { - return "", errors.New("unable to generate mac") - } - - // Valid token? - if !hmac.Equal(mac, expected) { - return "", errors.New("invalid cookie mac") - } - - expires, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return "", errors.New("unable to parse cookie expiry") - } +func (a *Authenticator) ValidateCookie(r *http.Request, c *http.Cookie) (*ID, error) { + var data ID - // Has it expired? - if time.Unix(expires, 0).Before(time.Now()) { - return "", errors.New("cookie has expired") + if err := a.secureCookie.Decode(a.config.CookieName, c.Value, &data); err != nil { + return nil, err } - // Looks valid - return parts[2], nil + return &data, nil } // ValidateEmail validates that the provided email ends with one of the configured Domains or is part of the configured Whitelist. @@ -117,14 +101,21 @@ func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) { // Cookie methods // MakeIDCookie creates an auth cookie -func (a *Authenticator) MakeIDCookie(r *http.Request, email string) *http.Cookie { +func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string) *http.Cookie { expires := a.cookieExpiry() - mac := a.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix())) - value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email) + data := &ID{ + Email: email, + Token: token, + } + + encoded, err := a.secureCookie.Encode(a.config.CookieName, data) + if err != nil { + return nil + } return &http.Cookie{ Name: a.config.CookieName, - Value: value, + Value: encoded, Path: "/", Domain: a.GetCookieDomain(r), HttpOnly: true, @@ -249,15 +240,6 @@ func (a *Authenticator) matchCookieDomains(domain string) (bool, string) { return false, p[0] } -// cookieSignature creates a cookie hmac -func (a *Authenticator) cookieSignature(r *http.Request, email, expires string) string { - hash := hmac.New(sha256.New, a.config.Secret) - hash.Write([]byte(a.GetCookieDomain(r))) - hash.Write([]byte(email)) - hash.Write([]byte(expires)) - return base64.URLEncoding.EncodeToString(hash.Sum(nil)) -} - // Get cookie expirary func (a *Authenticator) cookieExpiry() time.Time { return time.Now().Local().Add(a.config.Lifetime) diff --git a/internal/authentication/auth_test.go b/internal/authentication/auth_test.go index 7c4d21c..b1cb63e 100644 --- a/internal/authentication/auth_test.go +++ b/internal/authentication/auth_test.go @@ -5,67 +5,71 @@ import ( "github.com/mesosphere/traefik-forward-auth/internal/configuration" "github.com/mesosphere/traefik-forward-auth/internal/util" "net/http" - "strings" "testing" "time" "github.com/stretchr/testify/assert" ) +var ( + testAuthKey1 = "4Zhbg4n22r4I8Kdg1gHMzRWQpT7TOArD" + testEncKey1 = "8jAnK6NGuzEuH3y13V+5Bm2jgp5bv8ku" +) + +func newTestConfig(authKey, encKey string) *configuration.Config { + c, _ := configuration.NewConfig([]string{}) + c.SecretString = authKey + c.EncryptionKeyString = encKey + + return c +} + /** * Tests */ func TestAuthValidateCookie(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) a := NewAuthenticator(config) r, _ := http.NewRequest("GET", "http://example.com", nil) c := &http.Cookie{} - // Should require 3 parts + // Should not accept an empty value c.Value = "" _, err := a.ValidateCookie(r, c) if assert.Error(err) { - assert.Equal("invalid cookie format", err.Error()) - } - c.Value = "1|2" - _, err = a.ValidateCookie(r, c) - if assert.Error(err) { - assert.Equal("invalid cookie format", err.Error()) - } - c.Value = "1|2|3|4" - _, err = a.ValidateCookie(r, c) - if assert.Error(err) { - assert.Equal("invalid cookie format", err.Error()) + assert.Equal("securecookie: the value is not valid", err.Error()) } // Should catch invalid mac - c.Value = "MQ==|2|3" + c.Value = "MQ==" _, err = a.ValidateCookie(r, c) if assert.Error(err) { - assert.Equal("invalid cookie mac", err.Error()) + assert.Equal("securecookie: the value is not valid", err.Error()) } // Should catch expired config.Lifetime = time.Second * time.Duration(-1) - c = a.MakeIDCookie(r, "test@test.com") + a = NewAuthenticator(config) + c = a.MakeIDCookie(r, "test@test.com", "") _, err = a.ValidateCookie(r, c) if assert.Error(err) { - assert.Equal("cookie has expired", err.Error()) + assert.Equal("securecookie: expired timestamp", err.Error()) } // Should accept valid cookie config.Lifetime = time.Second * time.Duration(10) - c = a.MakeIDCookie(r, "test@test.com") - email, err := a.ValidateCookie(r, c) + a = NewAuthenticator(config) + c = a.MakeIDCookie(r, "test@test.com", "") + id, err := a.ValidateCookie(r, c) assert.Nil(err, "valid request should not return an error") - assert.Equal("test@test.com", email, "valid request should return user email") + assert.Equal("test@test.com", id.Email, "valid request should return user email") } func TestAuthValidateEmail(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) a := NewAuthenticator(config) // Should allow any @@ -106,7 +110,7 @@ func TestAuthValidateEmail(t *testing.T) { // } func getConfigWithLifetime() *configuration.Config { - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) // Lifetime is set during validation, so we short circuit it here config.Lifetime = time.Second * time.Duration(config.LifetimeString) return config @@ -120,10 +124,9 @@ func TestAuthMakeCookie(t *testing.T) { r, _ := http.NewRequest("GET", "http://app.example.com", nil) r.Header.Add("X-Forwarded-Host", "app.example.com") - c := a.MakeIDCookie(r, "test@example.com") + c := a.MakeIDCookie(r, "test@example.com", "") assert.Equal("_forward_auth", c.Name) - parts := strings.Split(c.Value, "|") - assert.Len(parts, 3, "cookie should be 3 parts") + assert.Greater(len(c.Value), 18, "encoded securecookie should be longer") _, err := a.ValidateCookie(r, c) assert.Nil(err, "should generate valid cookie") assert.Equal("/", c.Path) @@ -135,7 +138,7 @@ func TestAuthMakeCookie(t *testing.T) { config.CookieName = "testname" config.InsecureCookie = true - c = a.MakeIDCookie(r, "test@example.com") + c = a.MakeIDCookie(r, "test@example.com", "") assert.Equal("testname", c.Name) assert.False(c.Secure) } diff --git a/internal/authorization/authorizer.go b/internal/authorization/authorizer.go index a3de839..5fa5681 100644 --- a/internal/authorization/authorizer.go +++ b/internal/authorization/authorizer.go @@ -1,6 +1,8 @@ package authorization +import "net/url" + // Authorizer is the interface for implementing user authorization (check to see if the user can perform the action) type Authorizer interface { - Authorize(user User, requestVerb, requestResource string) (bool, error) + Authorize(user User, requestVerb string, resource *url.URL) (bool, error) } diff --git a/internal/authorization/rbac/rbac.go b/internal/authorization/rbac/rbac.go index 3331eac..383273c 100644 --- a/internal/authorization/rbac/rbac.go +++ b/internal/authorization/rbac/rbac.go @@ -2,6 +2,7 @@ package rbac import ( "log" + "net/url" "os" "strings" "time" @@ -133,7 +134,7 @@ func (ra *Authorizer) GetRolesBoundToUser(user authorization.User) (*rbacv1.Clus // Interface methods // Authorize performs the authorization logic -func (ra *Authorizer) Authorize(user authorization.User, requestVerb, requestResource string) (bool, error) { +func (ra *Authorizer) Authorize(user authorization.User, requestVerb string, requestURL *url.URL) (bool, error) { roles, err := ra.GetRolesBoundToUser(user) if err != nil { return false, err @@ -147,7 +148,7 @@ func (ra *Authorizer) Authorize(user authorization.User, requestVerb, requestRes // check all rules in the list of roles to see if any matches for _, role := range roles.Items { for _, rule := range role.Rules { - if verbMatches(&rule, requestVerb) && nonResourceURLMatches(&rule, requestResource) { + if verbMatches(&rule, requestVerb) && nonResourceURLMatches(&rule, requestURL) { return true, nil } } @@ -175,15 +176,33 @@ func verbMatches(rule *rbacv1.PolicyRule, requestedVerb string) bool { } // nonResourceURLMatches returns true if the requested URL matches a policy the rule -func nonResourceURLMatches(rule *rbacv1.PolicyRule, requestedURL string) bool { +func nonResourceURLMatches(rule *rbacv1.PolicyRule, requestedURL *url.URL) bool { for _, ruleURL := range rule.NonResourceURLs { if ruleURL == rbacv1.NonResourceAll { + // any (*) resource matches immediatelly return true - } - if authorization.PathMatches(requestedURL, ruleURL) { - return true + } else if len(ruleURL) > 0 { + // determine match type depending on the first rune: + + if ruleURL[0] == '~' { // regular expression match against the full url requested + fullURLWithoutQuery := requestedURL.Scheme + "://" + requestedURL.Host + requestedURL.Path + if authorization.URLMatchesRegexp(fullURLWithoutQuery, ruleURL[1:]) { + return true // return only if it matched + } + } else if ruleURL[0] == '/' { // path-only prefix match with optional wildcards (backward-compatible) + if authorization.URLMatchesWildcardPattern(requestedURL.Path, ruleURL) { + return true // return only if it matched + } + } else { // full url path-only prefix match with optional wildcards + fullURLWithoutQuery := requestedURL.Scheme + "://" + requestedURL.Host + requestedURL.Path + if authorization.URLMatchesWildcardPattern(fullURLWithoutQuery, ruleURL) { + return true // return only if it matched + } + } } } + + // no rule matched return false } diff --git a/internal/authorization/rbac/rbac_test.go b/internal/authorization/rbac/rbac_test.go index 64bcaef..c384b84 100644 --- a/internal/authorization/rbac/rbac_test.go +++ b/internal/authorization/rbac/rbac_test.go @@ -1,9 +1,11 @@ package rbac import ( + "net/url" "testing" - "gotest.tools/assert" + //"gotest.tools/assert" + "github.com/stretchr/testify/assert" rbacv1 "k8s.io/api/rbac/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -20,10 +22,19 @@ const ( type testCase struct { user authorization.User verb string - url string + url *url.URL should bool } +// makeURL makes url.URL object from relative path (e.g. /test) or full URL (e.g. http://domain.com/test) +func makeURL(str string) *url.URL { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + return u +} + func getRBACAuthorizer(objs ...runtime.Object) *Authorizer { return NewAuthorizer(fake.NewSimpleClientset(objs...), nil) } @@ -100,23 +111,39 @@ func TestRBACAuthorizer_GetRoles(t *testing.T) { u1 := authorization.User{Name: "u1"} r, err := a.GetRolesBoundToUser(u1) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(r.Items), 2) - assert.Equal(t, r.Items[0].Name, "r1") - assert.Equal(t, r.Items[1].Rules[0].NonResourceURLs[0], "/admin") + + // order of items in the returned array is actually not fixed so an iteration is required: + + hasr1Role := false + for _, role := range r.Items { + if role.Name == "r1" { + hasr1Role = true + } + } + assert.Equal(t, hasr1Role, true) + + hasAdminPathInRole := false + for _, role := range r.Items { + if role.Rules[0].NonResourceURLs[0] == "/admin" { + hasAdminPathInRole = true + } + } + assert.Equal(t, hasAdminPathInRole, true) u2 := authorization.User{Name: "u2", Groups: []string{"g1", "g2"}} r, err = a.GetRolesBoundToUser(u2) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(r.Items), 1) assert.Equal(t, r.Items[0].Name, "r3") } func TestRBACAuthorizer_Authorize(t *testing.T) { tests := []testCase{ - {authorization.User{Name: "u1"}, "get", "/", allow}, - {authorization.User{Name: "u1"}, "post", "/", deny}, + {authorization.User{Name: "u1"}, "get", makeURL("/"), allow}, + {authorization.User{Name: "u1"}, "post", makeURL("/"), deny}, } roles := makeClusterRoleList( @@ -134,7 +161,7 @@ func TestRBACAuthorizer_Authorize(t *testing.T) { for _, test := range tests { result, err := a.Authorize(test.user, test.verb, test.url) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, result, test.should) } } @@ -143,17 +170,17 @@ func TestRBACAuthorizer_Authorize2(t *testing.T) { test := testCase{ user: authorization.User{Name: "boyle@ldap.forumsys.com", Groups: []string{"oidc:chemists"}}, - url: "/ops/portal/grafana/public/fonts/roboto/RxZJdnzeo3R5zSexge8UUVtXRa8TVwTICgirnJhmVJw.woff2", + url: makeURL("/ops/portal/grafana/public/fonts/roboto/RxZJdnzeo3R5zSexge8UUVtXRa8TVwTICgirnJhmVJw.woff2"), should: allow, } - role := makeRole("grafana-admin", []string{"*"}, []string{"/ops/portal/grafana", "/ops/portal/grafana/*"}) + role := makeRole("grafana-admin", []string{"*"}, []string{"/ops/portal/grafana", "/ops/portal/grafana/**"}) rolebinding := makeBinding("User", "grafana-admin-boyle", "boyle@ldap.forumsys.com", "grafana-admin") a := getRBACAuthorizer(&role, &rolebinding) result, err := a.Authorize(test.user, test.verb, test.url) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, result, test.should) } @@ -209,8 +236,86 @@ func TestCaseInsensitiveSubjects(t *testing.T) { } for _, test := range tests { - result, err := test.authorizer.Authorize(test.user, "GET", test.url) - assert.NilError(t, err) + result, err := test.authorizer.Authorize(test.user, "GET", makeURL(test.url)) + assert.NoError(t, err) assert.Equal(t, result, test.should) } } + +func TestRBACAuthorizer_AuthorizePatternTypes(t *testing.T) { + tests := []testCase{ + // user with "visitor" role only can GET root "/" path on any domain but cannot access any "/admin" or other URLs + {authorization.User{Name: "u1"}, "get", makeURL("/"), allow}, + {authorization.User{Name: "u1"}, "get", makeURL("https://testdomain.com/"), allow}, + {authorization.User{Name: "u1"}, "post", makeURL("/admin"), deny}, + {authorization.User{Name: "u1"}, "post", makeURL("https://testdomain.com/admin"), deny}, + {authorization.User{Name: "u1"}, "post", makeURL("/reports"), deny}, + {authorization.User{Name: "u1"}, "post", makeURL("https://testdomain.com/reports"), deny}, + {authorization.User{Name: "u1"}, "post", makeURL("https://finance.com/finances"), deny}, + + // user with "visitor" & "admin" roles can GET root "/" path on any domain and POST/ANY to "/admin" path on every domain + {authorization.User{Name: "u2"}, "get", makeURL("/"), allow}, + {authorization.User{Name: "u2"}, "get", makeURL("https://testdomain.com/"), allow}, + {authorization.User{Name: "u2"}, "post", makeURL("/admin"), allow}, + {authorization.User{Name: "u2"}, "post", makeURL("https://testdomain.com/admin"), allow}, + {authorization.User{Name: "u2"}, "delete", makeURL("https://testdomain.com/admin"), allow}, + {authorization.User{Name: "u2"}, "delete", makeURL("https://testdomain.com/boss"), deny}, + + // user with "visitor" & "testdomain-only-admin-poster" role can GET "/" path on any domain + // but POST to "/admin" on testdomain.com only + {authorization.User{Name: "u3"}, "get", makeURL("/"), allow}, + {authorization.User{Name: "u3"}, "get", makeURL("https://testdomain.com/"), allow}, + {authorization.User{Name: "u3"}, "post", makeURL("/admin"), deny}, + {authorization.User{Name: "u3"}, "post", makeURL("https://customdomain.com/admin"), deny}, + {authorization.User{Name: "u3"}, "post", makeURL("https://testdomain.com/admin"), allow}, + {authorization.User{Name: "u3"}, "delete", makeURL("https://testdomain.com/admin"), deny}, + + // user with "wilddomain-only-admin-poster" role can only POST to "/admin" on URLs matching "*://*domain.com/admin" + {authorization.User{Name: "u4"}, "post", makeURL("https://facebook.com/admin"), deny}, + {authorization.User{Name: "u4"}, "post", makeURL("https://customdomain.com/admin"), allow}, + {authorization.User{Name: "u4"}, "post", makeURL("https://testdomain.com/admin"), allow}, + {authorization.User{Name: "u4"}, "post", makeURL("https://testdomain.com/admin/res/theme.css"), deny}, // no * at the end + {authorization.User{Name: "u4"}, "post", makeURL("http://customdomain.com/admin"), allow}, // same as https:// version + {authorization.User{Name: "u4"}, "post", makeURL("http://testdomain.com/admin"), allow}, // same as https:// version + {authorization.User{Name: "u4"}, "post", makeURL("https://testdomain.com/reports"), deny}, + + // user with "https-regexpdomain-admin-poster" can post to anything under https://(first|second)domain/admin path + {authorization.User{Name: "u5"}, "get", makeURL("https://firstdomain.com/"), deny}, + {authorization.User{Name: "u5"}, "get", makeURL("https://firstdomain.com/admin"), deny}, + {authorization.User{Name: "u5"}, "post", makeURL("https://firstdomain.com/admin"), allow}, + {authorization.User{Name: "u5"}, "post", makeURL("https://firstdomain.com/admin/users/create"), allow}, + {authorization.User{Name: "u5"}, "post", makeURL("https://seconddomain.com/admin/users/create"), allow}, + } + + roles := makeClusterRoleList( + makeRole("visitor", []string{"get"}, []string{"/"}), + makeRole("admin", []string{"*"}, []string{"/admin"}), + makeRole("testdomain-only-admin-poster", []string{"post"}, []string{"https://testdomain.com/admin"}), + makeRole("wilddomain-only-admin-poster", []string{"post"}, []string{"*://*domain.com/admin"}), + makeRole("https-regexpdomain-admin-poster", []string{"post"}, []string{"~^https://(first|second)domain.com/admin"}), + ) + + bindings := makeClusterRoleBindingList( + makeBinding("User", "u1b1", "u1", "visitor"), + + makeBinding("User", "u2b1", "u2", "visitor"), + makeBinding("User", "u2b2", "u2", "admin"), + + makeBinding("User", "u3b1", "u3", "visitor"), + makeBinding("User", "u3b2", "u3", "testdomain-only-admin-poster"), + + makeBinding("User", "u4b1", "u4", "wilddomain-only-admin-poster"), + + makeBinding("User", "u5b1", "u5", "https-regexpdomain-admin-poster"), + ) + a := getRBACAuthorizer(roles, bindings) + + for _, test := range tests { + result, err := a.Authorize(test.user, test.verb, test.url) + assert.NoError(t, err) + + if !assert.Equal(t, result, test.should) { + t.Logf("Authorize(%v, %v, %v) != %v", test.user, test.verb, test.url, test.should) + } + } +} diff --git a/internal/authorization/urlpatterns.go b/internal/authorization/urlpatterns.go new file mode 100644 index 0000000..0eb81ea --- /dev/null +++ b/internal/authorization/urlpatterns.go @@ -0,0 +1,126 @@ +package authorization + +import ( + "errors" + "regexp" + "strings" + "sync" +) + +var ( + globalRECache = newRegexpCache() + invalidExpr = ®exp.Regexp{} + errInvalidExpr = errors.New("invalid regular expression") +) + +type regexpCache struct { + mu sync.RWMutex + cache map[string]*regexp.Regexp +} + +func newRegexpCache() *regexpCache { + return ®expCache{ + cache: make(map[string]*regexp.Regexp), + } +} + +// get returns regexp cached under "expr" key or nil if not cached +func (rc *regexpCache) get(expr string) *regexp.Regexp { + rc.mu.RLock() + defer rc.mu.RUnlock() + + re, _ := rc.cache[expr] + return re +} + +// wildcardPatternToRegexp converts pattern containing optional * characters +// to a regular expression string. A special care is taken to quote +// any regular expression characters in the input pattern first. +func (rc *regexpCache) wildcardPatternToRegexp(pattern string) string { + // quote all regexp metacharacters to make the safe expression which would match + // the input as being a literal string (basically a regexp for string equality test) + pattern = regexp.QuoteMeta(pattern) + // replace two ** input characters (now quoted by '\') with an expression to match anything + pattern = strings.ReplaceAll(pattern, `\*\*`, `.*`) + // replace the remaining single escaped '*' with an expression to match continous stream + // (by using non-greedy ? specifier) of characters not containing any slash (path separator) + pattern = strings.ReplaceAll(pattern, `\*`, `[^/]*?`) + // request pattern to match the subject fully by adding beginning and ending anchors + return `^` + pattern + `$` +} + +// GetOrCompile attempts to get compiled regexp from the cache or attempts to compile it and cache. +// If the expr is not a valid expression, the function forwards the error from regexp.Compile +// If asWildcard is true, the expression will be interpreted as a wildcard pattern +func (rc *regexpCache) GetOrCompile(expr string, asWildcard bool) (*regexp.Regexp, error) { + var err error + + // attempt to get already-compiled regexp from cache first + re := rc.get(expr) + if re != nil { + if re == invalidExpr { + // if invalid expression is cached, return early with an error + return nil, errInvalidExpr + } + return re, nil // return cached Regexp object + } + + if strings.TrimSpace(expr) == "" { + // mark empty expr as failed compilation for extra safety + err = errInvalidExpr + re = nil + } else { + // attempt to compile a new regexp + exprToCompile := expr + if asWildcard { + // if wildcard mode requested, convert first + exprToCompile = rc.wildcardPatternToRegexp(expr) + } + re, err = regexp.Compile(exprToCompile) + } + + // cache failed regexpes as invalid to prevent their re-compilation + if err != nil { + re = invalidExpr + } + + // store in the cache under the input "expr" as the key + rc.mu.Lock() + rc.cache[expr] = re + rc.mu.Unlock() + + return re, err +} + +// MatchString returns true if subject matches the expression expr. +// The expr parameter can be a regular expression string or wildcard pattern (if asWildcard is true). +// Wildcard pattern may not contain any '*' character in which case a direct string comparison is performed. +func (rc *regexpCache) MatchString(subject string, expr string, asWildcard bool) bool { + // Attempt to speed up asWildcard=true patterns by doing quick direct literal comparison first. + // Wildcard patterns always match themselves ("x*x", "x*x", true) = true and the speed tests + // suggest s negative result of this comparison do not make significant difference in total speed + if asWildcard && expr == subject { + return true + } + + re, err := rc.GetOrCompile(expr, asWildcard) + if err != nil { + return false + } + + return re.MatchString(subject) +} + +// URLMatchesRegexp returns true if the URL matches the regular expresson +func URLMatchesRegexp(url, regex string) bool { + return globalRECache.MatchString(url, regex, false) +} + +// URLMatchesWildcardPattern returns true if the URL matches the pattern containing optional wildcard '*' characters +func URLMatchesWildcardPattern(url, pattern string) bool { + // original implementation: + // return pattern == url || + // (strings.HasSuffix(pattern, "*") && strings.HasPrefix(url, strings.TrimRight(pattern, "*"))) + + return globalRECache.MatchString(url, pattern, true) +} diff --git a/internal/authorization/urlpatterns_test.go b/internal/authorization/urlpatterns_test.go new file mode 100644 index 0000000..6cc40bb --- /dev/null +++ b/internal/authorization/urlpatterns_test.go @@ -0,0 +1,82 @@ +package authorization + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWildcardMatches(t *testing.T) { + type test struct { + pattern string + url string + matches bool + } + + var testCases = []test{ + {pattern: "", url: "", matches: true}, + {pattern: "", url: "/", matches: false}, + {pattern: "/", url: "", matches: false}, + {pattern: "/", url: "/", matches: true}, + {pattern: "/xyz/*/subpath3", url: "/xyz/subpath1/subpath2", matches: false}, + {pattern: "/ops/portal/", url: "/ops/portal/admin", matches: false}, + {pattern: "/ops/portal/*.png", url: "/ops/portal/xyz.png", matches: true}, + {pattern: "/ops/portal/*.png", url: "/ops/portal/res/xyz.png", matches: false}, + {pattern: "/ops/portal/**/*.png", url: "/ops/portal/res/xyz.png", matches: true}, + {pattern: "/ops/portal/kibana", url: "/ops/portal/kibana/app/kibana", matches: false}, + {pattern: "/ops/portal/kibana/**", url: "/ops/portal/kibana/app/kibana", matches: true}, + {pattern: "/ops/portal/grafana/**", url: "/ops/portal/grafana/public/img/fav32.png", matches: true}, + {pattern: "/ops/portal/grafana/**", url: "/ops/portal/grafana/public/build/runtime.3932bda029d2299a9d96.js", matches: true}, + } + + for _, c := range testCases { + if !assert.Equal(t, c.matches, URLMatchesWildcardPattern(c.url, c.pattern)) { + t.Logf("URLMatchesWildcardPattern(%v, %v) != %v", c.url, c.pattern, c.matches) + } + } +} + +func TestRegexpMatches(t *testing.T) { + type test struct { + pattern string + url string + matches bool + } + + var testCases = []test{ + {pattern: ``, url: "", matches: false}, + {pattern: ``, url: "/", matches: false}, + {pattern: `/`, url: "", matches: false}, + {pattern: `/`, url: "/", matches: true}, + {pattern: `https?://(my|our)domain.com/`, url: "http://mydomain.com/", matches: true}, + {pattern: `https?://(my|our)domain.com/`, url: "https://mydomain.com/", matches: true}, + {pattern: `https?://(my|our)domain.com/`, url: "http://ourdomain.com/", matches: true}, + {pattern: `https?://(my|our)domain.com/`, url: "http://ourdomain.com/admin", matches: true}, + {pattern: `https?://(my|our)domain.com/`, url: "http://theirdomain.com/", matches: false}, + // remember such regexp matches pattern anywhere in the URL unless anchored! + {pattern: `https?://(my|our)domain.com/`, url: "http://safedomain.com/?fakestring=http://mydomain.com/", matches: true}, + // same here, it matches this generic /admin pattern anywhere in the URL + {pattern: `/admin`, url: "https://theirdomain.com/admin", matches: true}, + {pattern: `/admin`, url: "https://theirdomain.com/admin/res/logo.jpg", matches: true}, + // can be anchored like an Nginx location block for PHP + {pattern: `\.php$`, url: "https://theirdomain.com/survey/index.php", matches: true}, + {pattern: `\.php$`, url: "https://theirdomain.com/survey/index.php/extra/path", matches: false}, + {pattern: `^https?://ourdomain.com/admin/.*`, url: "https://ourdomain.com/admin/", matches: true}, + {pattern: `^https?://ourdomain.com/admin/.*`, url: "https://ourdomain.com/admin/users", matches: true}, + {pattern: `^https?://ourdomain.com/admin/.*`, url: "https://ourdomain.com/admin/static/theme.css", matches: true}, + {pattern: `^https?://ourdomain.com/admin/.*`, url: "https://ourdomain.com/", matches: false}, + {pattern: `^https?://ourdomain.com/admin/.*`, url: "https://ourdomain.com/about-us/", matches: false}, + {pattern: `^https?://[^./]+.com/admin/.*`, url: "https://ourdomain.com/about-us/", matches: false}, + {pattern: `^https?://[^./]+.com/admin/.*`, url: "https://ourdomain.com/admin/", matches: true}, + {pattern: `^https?://[^./]+.com/admin/.*`, url: "https://google.com/about-us/", matches: false}, + {pattern: `^https?://[^./]+.com/admin/.*`, url: "https://google.com/admin/", matches: true}, + {pattern: `^https?://[^./]+.com/admin/.*`, url: "https://eff.org/admin/", matches: false}, + {pattern: `^https?://[^/]+/`, url: "https://www.google.com/", matches: true}, + } + + for _, c := range testCases { + if !assert.Equal(t, c.matches, URLMatchesRegexp(c.url, c.pattern)) { + t.Logf("URLMatchesRegexp(%v, %v) != %v", c.url, c.pattern, c.matches) + } + } +} diff --git a/internal/authorization/util.go b/internal/authorization/util.go deleted file mode 100644 index 80e6e29..0000000 --- a/internal/authorization/util.go +++ /dev/null @@ -1,11 +0,0 @@ -package authorization - -import ( - "strings" -) - -// PathMatches returns true if the URL matches the pattern containing an optional wildcard '*' character -func PathMatches(url, pattern string) bool { - return pattern == url || - (strings.HasSuffix(pattern, "*") && strings.HasPrefix(url, strings.TrimRight(pattern, "*"))) -} diff --git a/internal/authorization/util_test.go b/internal/authorization/util_test.go deleted file mode 100644 index f73b5d9..0000000 --- a/internal/authorization/util_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package authorization - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestPathMatches(t *testing.T) { - type test struct { - url string - pattern string - allow bool - } - - var testCases = []test{ - {url: "/", pattern: "/", allow: true}, - {url: "/xyz/subpath1/subpath2", pattern: "/xyz/*/subpath3", allow: false}, - {url: "/ops/portal/admin", pattern: "/ops/portal/", allow: false}, - {url: "/ops/portal/xyz.png", pattern: "/ops/portal/*.png", allow: false}, - {url: "/ops/portal/kibana/app/kibana", pattern: "/ops/portal/kibana/*", allow: true}, - {url: "/ops/portal/grafana/public/img/fav32.png", pattern: "/ops/portal/grafana/*", allow: true}, - {url: "/ops/portal/grafana/public/build/runtime.3932bda029d2299a9d96.js", pattern: "/ops/portal/grafana/*", allow: true}, - } - - for _, c := range testCases { - assert.Equal(t, c.allow, PathMatches(c.url, c.pattern)) - } -} diff --git a/internal/configuration/config.go b/internal/configuration/config.go index 3ab0eba..88978fd 100644 --- a/internal/configuration/config.go +++ b/internal/configuration/config.go @@ -51,13 +51,15 @@ type Config struct { Domains CommaSeparatedList `long:"domain" env:"DOMAIN" description:"Only allow given email domains, can be set multiple times"` LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"` Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"` - SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"` + SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing the cookie (required)" json:"-"` Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Only allow given email addresses, can be set multiple times"` EnableImpersonation bool `long:"enable-impersonation" env:"ENABLE_IMPERSONATION" description:"Indicates that impersonation headers should be set on successful auth"` + ForwardTokenHeaderName string `long:"forward-token-header-name" env:"FORWARD_TOKEN_HEADER_NAME" description:"Header name to forward the raw ID token in (won't forward token if empty)"` + ForwardTokenPrefix string `long:"forward-token-prefix" env:"FORWARD_TOKEN_PREFIX" default:"Bearer " description:"Prefix string to add before the forwarded ID token"` ServiceAccountTokenPath string `long:"service-account-token-path" env:"SERVICE_ACCOUNT_TOKEN_PATH" default:"/var/run/secrets/kubernetes.io/serviceaccount/token" description:"When impersonation is enabled, this token is passed via the Authorization header to the ingress. The user associated with the token must have impersonation privileges."` Rules map[string]*Rule `long:"rules.." description:"Rule definitions, param can be: \"action\" or \"rule\""` GroupClaimPrefix string `long:"group-claim-prefix" env:"GROUP_CLAIM_PREFIX" default:"oidc:" description:"prefix oidc group claims with this value"` - SessionKey string `long:"session-key" env:"SESSION_KEY" description:"A session key used to encrypt browser sessions"` + EncryptionKeyString string `long:"encryption-key" env:"ENCRYPTION_KEY" description:"Encryption key used to encrypt the cookie (required)" json:"-"` GroupsAttributeName string `long:"groups-attribute-name" env:"GROUPS_ATTRIBUTE_NAME" default:"groups" description:"Map the correct attribute that contain the user groups"` // RBAC @@ -73,23 +75,23 @@ type Config struct { // Filled during transformations OIDCContext context.Context OIDCProvider *oidc.Provider - Secret []byte `json:"-"` Lifetime time.Duration ServiceAccountToken string } -func NewGlobalConfig(args []string) *Config { +func NewGlobalConfig(args []string) (*Config, error) { var err error config, err = NewConfig(args) - if err != nil { - fmt.Printf("%+v\n", err) - os.Exit(1) - } - return config + return config, err } +// NewConfig loads config from provided args or uses os.Args if nil func NewConfig(args []string) (*Config, error) { + if args == nil && len(os.Args) > 0 { + args = os.Args[1:] + } + c := Config{ Rules: map[string]*Rule{}, } @@ -233,7 +235,7 @@ func (c *Config) Validate() { if len(c.Path) > 0 && c.Path[0] != '/' { c.Path = "/" + c.Path } - c.Secret = []byte(c.SecretString) + c.Lifetime = time.Second * time.Duration(c.LifetimeString) // get service account token @@ -244,13 +246,6 @@ func (c *Config) Validate() { } c.ServiceAccountToken = strings.TrimSuffix(string(t), "\n") } - - // RBAC - if c.EnableRBAC && len(c.SessionKey) != 16 && len(c.SessionKey) != 24 && len(c.SessionKey) != 32 { - // Gorilla sessions require encryption keys of specific length - // https://www.gorillatoolkit.org/pkg/sessions#NewCookieStore - log.Fatal("\"session-key\" must be 16, 24 or 32 bytes long to select AES-128, AES-192, or AES-256 modes") - } } // LoadOIDCProviderConfiguration loads the configuration of OpenID Connect provider @@ -270,11 +265,13 @@ func (c Config) String() string { return string(jsonConf) } +// Rule specifies an action for the rule type Rule struct { Action string Rule string } +// NewRule creates a new Rule instance func NewRule() *Rule { return &Rule{ Action: "auth", @@ -287,6 +284,7 @@ func (r *Rule) FormattedRule() string { return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(") } +// Validate validates the rule func (r *Rule) Validate() { if r.Action != "auth" && r.Action != "allow" { log.Fatal("invalid rule action, must be \"auth\" or \"allow\"") @@ -295,13 +293,16 @@ func (r *Rule) Validate() { // Legacy support for comma separated lists +// CommaSeparatedList flag value type CommaSeparatedList []string +// UnmarshalFlag unmarshals a comma-separated list from the flag value func (c *CommaSeparatedList) UnmarshalFlag(value string) error { *c = append(*c, strings.Split(value, ",")...) return nil } +// MarshalFlag marshals the comma-separated list to the flag value func (c *CommaSeparatedList) MarshalFlag() (string, error) { return strings.Join(*c, ","), nil } diff --git a/internal/configuration/config_test.go b/internal/configuration/config_test.go index 36eaa83..ace68e5 100644 --- a/internal/configuration/config_test.go +++ b/internal/configuration/config_test.go @@ -144,9 +144,6 @@ func TestConfigTransformation(t *testing.T) { assert.Equal("/_oauthpath", c.Path, "path should add slash to front") - assert.Equal("verysecret", c.SecretString) - assert.Equal([]byte("verysecret"), c.Secret, "secret should be converted to byte array") - assert.Equal(200, c.LifetimeString) assert.Equal(time.Second*time.Duration(200), c.Lifetime, "lifetime should be read and converted to duration") } diff --git a/internal/handlers/server.go b/internal/handlers/server.go index 0586f01..3f0be7a 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -3,6 +3,7 @@ package handlers import ( "fmt" "net/http" + "net/url" neturl "net/url" "strings" @@ -137,7 +138,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { } // Validate cookie - email, err := s.authenticator.ValidateCookie(r, c) + id, err := s.authenticator.ValidateCookie(r, c) if err != nil { logger.Info(fmt.Sprintf("cookie validaton failure: %s", err.Error())) s.notAuthenticated(logger, w, r) @@ -145,15 +146,22 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { } // Validate user - valid := s.authenticator.ValidateEmail(email) + valid := s.authenticator.ValidateEmail(id.Email) if !valid { logger.WithFields(logrus.Fields{ - "email": email, + "email": id.Email, }).Errorf("Invalid email") http.Error(w, "Not authorized", 401) return } + // Token forwarding requested now with no token stored in the session, reauth + if s.config.ForwardTokenHeaderName != "" && id.Token == "" { + logger.Info("re-auth forced because token forwarding enabled and no token stored") + s.notAuthenticated(logger, w, r) + return + } + // Authorize user groups, err := s.getGroupsFromSession(r) if err != nil { @@ -169,10 +177,19 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { } if s.config.EnableRBAC && !s.authzIsBypassed(r) { - kubeUserInfo := s.makeKubeUserInfo(email, groups) + kubeUserInfo := s.makeKubeUserInfo(id.Email, groups) + + requestURL := authentication.GetRequestURL(r) + + targetURL, err := url.Parse(requestURL) + if err != nil { + logger.Errorf("unable to parse target URL %s: %v", requestURL, err) + http.Error(w, "Bad Gateway", 502) + return + } logger.Debugf("authorizing user: %s, groups: %s", kubeUserInfo.Name, kubeUserInfo.Groups) - authorized, err := s.authorizer.Authorize(kubeUserInfo, r.Method, r.URL.Path) + authorized, err := s.authorizer.Authorize(kubeUserInfo, r.Method, targetURL) if err != nil { logger.Errorf("error while authorizing %s: %v", kubeUserInfo, err) http.Error(w, "Bad Gateway", 502) @@ -180,30 +197,37 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { } if !authorized { - logger.Infof("user %s for is not authorized to `%s` in %s", kubeUserInfo.GetName(), r.Method, r.URL.Path) + logger.Infof("user %s is not authorized to `%s` in %s", kubeUserInfo.GetName(), r.Method, targetURL) + //TODO:k3a: consider some kind of re-auth to recheck for new groups http.Error(w, "Not Authorized", 401) return } - logger.Infof("user %s is authorized to `%s` in %s", kubeUserInfo.GetName(), r.Method, r.URL.Path) + + logger.Infof("user %s is authorized to `%s` in %s", kubeUserInfo.GetName(), r.Method, targetURL) } // Valid request - logger.Debugf("Allow request from %s", email) + logger.Debugf("Allow request from %s", id.Email) for _, headerName := range s.config.EmailHeaderNames { - w.Header().Set(headerName, email) + w.Header().Set(headerName, id.Email) } if s.config.EnableImpersonation { // Set impersonation headers - logger.Debug(fmt.Sprintf("setting authorization token and impersonation headers: email: %s, groups: %s", email, groups)) + logger.Debug(fmt.Sprintf("setting authorization token and impersonation headers: email: %s, groups: %s", id.Email, groups)) w.Header().Set("Authorization", fmt.Sprintf("Bearer %s", s.config.ServiceAccountToken)) - w.Header().Set(impersonateUserHeader, email) + w.Header().Set(impersonateUserHeader, id.Email) w.Header().Set(impersonateGroupHeader, "system:authenticated") for _, group := range groups { w.Header().Add(impersonateGroupHeader, fmt.Sprintf("%s%s", s.config.GroupClaimPrefix, group)) } w.Header().Set("Connection", cleanupConnectionHeader(w.Header().Get("Connection"))) } + + if s.config.ForwardTokenHeaderName != "" && id.Token != "" { + w.Header().Add(s.config.ForwardTokenHeaderName, s.config.ForwardTokenPrefix+id.Token) + } + w.WriteHeader(200) } } @@ -307,11 +331,15 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { return } - // Generate cookies email, ok := claims["email"] if ok { + token := "" + if s.config.ForwardTokenHeaderName != "" { + token = rawIDToken + } + // Generate cookies - http.SetCookie(w, s.authenticator.MakeIDCookie(r, email.(string))) + http.SetCookie(w, s.authenticator.MakeIDCookie(r, email.(string), token)) logger.WithFields(logrus.Fields{ "user": claims["email"].(string), }).Infof("generated auth cookie") @@ -332,11 +360,10 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Mapping groups groups := []string{} - gInterface, ok := claims[s.config.GroupsAttributeName].([]interface{}) + groupsClaim, ok := claims[s.config.GroupsAttributeName].([]interface{}) if ok { - groups = make([]string, len(gInterface)) - for i, v := range gInterface { - groups[i] = v.(string) + for _, g := range groupsClaim { + groups = append(groups, g.(string)) } } else { logger.Errorf("failed to get groups claim from the ID token (GroupsAttributeName: %s)", s.config.GroupsAttributeName) @@ -352,6 +379,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { http.Error(w, "Bad Gateway", 502) return } + // Redirect http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) } @@ -463,7 +491,7 @@ func (s *Server) getGroupsFromSession(r *http.Request) ([]string, error) { // authzIsBypassed returns true if the request matches a bypass URI pattern func (s *Server) authzIsBypassed(r *http.Request) bool { for _, bypassURIPattern := range s.config.AuthZPassThrough { - if authorization.PathMatches(r.URL.Path, bypassURIPattern) { + if authorization.URLMatchesWildcardPattern(r.URL.Path, bypassURIPattern) { s.log.Infof("authorization is disabled for %s", r.URL.Path) return true } diff --git a/internal/handlers/server_test.go b/internal/handlers/server_test.go index 0a83c97..487fd01 100644 --- a/internal/handlers/server_test.go +++ b/internal/handlers/server_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "strings" "testing" "time" @@ -21,7 +20,20 @@ import ( intlog "github.com/mesosphere/traefik-forward-auth/internal/log" ) -// TODO: +var ( + testAuthKey1 = "4Zhbg4n22r4I8Kdg1gHMzRWQpT7TOArD" + testAuthKey2 = "HhaAG845dg9b16xKk8yiX+XoBhEAeHnQ" + testEncKey1 = "8jAnK6NGuzEuH3y13V+5Bm2jgp5bv8ku" + testEncKey2 = "FmvAqxzYy9ru0WaSU6SkLHP1ScoSVF/t" +) + +func newTestConfig(authKey, encKey string) *configuration.Config { + c, _ := configuration.NewConfig([]string{}) + c.SecretString = authKey + c.EncryptionKeyString = encKey + + return c +} /** * Setup @@ -80,29 +92,30 @@ func (f *fakeUserInfoStore) Save(r *http.Request, w http.ResponseWriter, info *v func TestServerAuthHandlerInvalid(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.AuthHost = "dex.example.com" config.Lifetime = time.Minute * time.Duration(config.LifetimeString) a := authentication.NewAuthenticator(config) // Should redirect vanilla request to login url - req := newDefaultHttpRequest("/foo") + req := newDefaultHTTPRequest("/foo") res, _ := doHttpRequest(req, nil, config) assert.Equal(307, res.StatusCode, "vanilla request should be redirected") // Should catch invalid cookie - req = newDefaultHttpRequest("/foo") - c := a.MakeIDCookie(req, "test@example.com") - parts := strings.Split(c.Value, "|") - c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2]) + req = newDefaultHTTPRequest("/foo") + // NOTE(jkoelker) `notAuthenticated` will redirect if it thinks the request is from a browser + req.Header.Set("Accept", "application/json") + c := a.MakeIDCookie(req, "test@example.com", "") + config = newTestConfig(testAuthKey2, testEncKey2) // new auth & encryption key! config.AuthHost = "" - config.OIDCProvider = &oidc.Provider{} res, _ = doHttpRequest(req, c, config) - assert.Equal(302, res.StatusCode, "invalid cookie should redirect") + assert.Equal(401, res.StatusCode, "invalid cookie should not be authorised") // Should validate email - req = newDefaultHttpRequest("/foo") - c = a.MakeIDCookie(req, "test@example.com") + req = newDefaultHTTPRequest("/foo") + a = authentication.NewAuthenticator(config) + c = a.MakeIDCookie(req, "test@example.com", "") config.Domains = []string{"test.com"} res, _ = doHttpRequest(req, c, config) @@ -111,15 +124,15 @@ func TestServerAuthHandlerInvalid(t *testing.T) { func TestServerAuthHandlerExpired(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.Lifetime = time.Second * time.Duration(-1) config.Domains = []string{"test.com"} config.AuthHost = "potato.example.com" a := authentication.NewAuthenticator(config) // Should redirect expired cookie - req := newDefaultHttpRequest("/foo") - c := a.MakeIDCookie(req, "test@example.com") + req := newDefaultHTTPRequest("/foo") + c := a.MakeIDCookie(req, "test@example.com", "") res, _ := doHttpRequest(req, c, config) assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected") @@ -132,13 +145,13 @@ func TestServerAuthHandlerExpired(t *testing.T) { func TestServerAuthHandlerValid(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.Lifetime = time.Minute * time.Duration(config.LifetimeString) a := authentication.NewAuthenticator(config) // Should allow valid request email - req := newDefaultHttpRequest("/foo") - c := a.MakeIDCookie(req, "test@example.com") + req := newDefaultHTTPRequest("/foo") + c := a.MakeIDCookie(req, "test@example.com", "") config.Domains = []string{} @@ -154,7 +167,7 @@ func TestServerAuthHandlerValid(t *testing.T) { // TODO: OIDC exchanges need to be mocked for AuthCallback testing //func TestServerAuthCallback(t *testing.T) { // assert := assert.New(t) -// config, _ = NewConfig([]string{}) +// config = newTestConfig(testAuthKey1, testEncKey1) // config.AuthHost = "potato.example.com" // // // Setup token server @@ -168,18 +181,18 @@ func TestServerAuthHandlerValid(t *testing.T) { // defer userServer.Close() // // // Should pass auth response request to callback -// req := newDefaultHttpRequest("/_oauth") +// req := newDefaultHTTPRequest("/_oauth") // res, _ := doHttpRequest(req, nil) // assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised") // // // Should catch invalid csrf cookie -// req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") +// req = newDefaultHTTPRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") // c := MakeCSRFCookie(req, "nononononononononononononononono") // res, _ = doHttpRequest(req, c) // assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised") // // // Should redirect valid request -// req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") +// req = newDefaultHTTPRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") // c = MakeCSRFCookie(req, "12345678901234567890123456789012") // res, _ = doHttpRequest(req, c) // assert.Equal(307, res.StatusCode, "valid auth callback should be allowed") @@ -192,22 +205,22 @@ func TestServerAuthHandlerValid(t *testing.T) { func TestServerDefaultAction(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.AuthHost = "potato.example.com" - req := newDefaultHttpRequest("/random") + req := newDefaultHTTPRequest("/random") res, _ := doHttpRequest(req, nil, config) assert.Equal(307, res.StatusCode, "request should require auth with auth default handler") config.DefaultAction = "allow" config.AuthHost = "" - req = newDefaultHttpRequest("/random") + req = newDefaultHTTPRequest("/random") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request should be allowed with default handler") } func TestServerRouteHeaders(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.AuthHost = "potato.example.com" config.Rules = map[string]*configuration.Rule{ "1": { @@ -221,20 +234,20 @@ func TestServerRouteHeaders(t *testing.T) { } // Should block any request - req := newDefaultHttpRequest("/random") + req := newDefaultHTTPRequest("/random") req.Header.Add("X-Random", "hello") res, _ := doHttpRequest(req, nil, config) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") config.AuthHost = "" // Should allow matching - req = newDefaultHttpRequest("/api") + req = newDefaultHTTPRequest("/api") req.Header.Add("X-Test", "test123") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") // Should allow matching - req = newDefaultHttpRequest("/api") + req = newDefaultHTTPRequest("/api") req.Header.Add("X-Test", "test789") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") @@ -242,7 +255,7 @@ func TestServerRouteHeaders(t *testing.T) { func TestServerRouteHost(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.Rules = map[string]*configuration.Rule{ "1": { Action: "allow", @@ -257,25 +270,25 @@ func TestServerRouteHost(t *testing.T) { config.AuthHost = "potato.example.com" // Should block any request - req := newHttpRequest("GET", "https://example.com/", "/") + req := newHTTPRequest("GET", "https://example.com/", "/") res, _ := doHttpRequest(req, nil, config) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") config.AuthHost = "" // Should allow matching request - req = newHttpRequest("GET", "https://api.example.com/", "/") + req = newHTTPRequest("GET", "https://api.example.com/", "/") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") // Should allow matching request - req = newHttpRequest("GET", "https://sub8.example.com/", "/") + req = newHTTPRequest("GET", "https://sub8.example.com/", "/") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } func TestServerRouteMethod(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.Rules = map[string]*configuration.Rule{ "1": { Action: "allow", @@ -285,20 +298,20 @@ func TestServerRouteMethod(t *testing.T) { config.AuthHost = "potato.example.com" // Should block any request - req := newHttpRequest("GET", "https://example.com/", "/") + req := newHTTPRequest("GET", "https://example.com/", "/") res, _ := doHttpRequest(req, nil, config) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") config.AuthHost = "" // Should allow matching request - req = newHttpRequest("PUT", "https://example.com/", "/") + req = newHTTPRequest("PUT", "https://example.com/", "/") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } func TestServerRoutePath(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.Rules = map[string]*configuration.Rule{ "1": { Action: "allow", @@ -312,29 +325,29 @@ func TestServerRoutePath(t *testing.T) { config.AuthHost = "potato.example.com" // Should block any request - req := newDefaultHttpRequest("/random") + req := newDefaultHTTPRequest("/random") res, _ := doHttpRequest(req, nil, config) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") config.AuthHost = "" // Should allow /api request - req = newDefaultHttpRequest("/api") + req = newDefaultHTTPRequest("/api") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") // Should allow /private request - req = newDefaultHttpRequest("/private") + req = newDefaultHTTPRequest("/private") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") - req = newDefaultHttpRequest("/private/path") + req = newDefaultHTTPRequest("/private/path") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } func TestServerRouteQuery(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.Rules = map[string]*configuration.Rule{ "1": { Action: "allow", @@ -344,20 +357,20 @@ func TestServerRouteQuery(t *testing.T) { config.AuthHost = "potato.example.com" // Should block any request - req := newHttpRequest("GET", "https://example.com/", "/?q=no") + req := newHTTPRequest("GET", "https://example.com/", "/?q=no") res, _ := doHttpRequest(req, nil, config) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") config.AuthHost = "" // Should allow matching request - req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123") + req = newHTTPRequest("GET", "https://api.example.com/", "/?q=test123") res, _ = doHttpRequest(req, nil, config) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } func TestAuthzDisabled(t *testing.T) { assert := assert.New(t) - config, _ := configuration.NewConfig([]string{}) + config := newTestConfig(testAuthKey1, testEncKey1) config.EnableRBAC = true config.AuthZPassThrough = []string{"/authz/passthru", "/authz/passthru/*"} @@ -435,11 +448,8 @@ func doHttpRequest(r *http.Request, c *http.Cookie, config *configuration.Config return res, string(body) } -func newDefaultHttpRequest(uri string) *http.Request { - return newHttpRequest("", "http://example.com/", uri) -} - -func newHttpRequest(method, dest, uri string) *http.Request { +// newHTTPRequest creates a mocked HTTP request from Traefik (with X-Forwarded-* headers) +func newHTTPRequest(method, dest, uri string) *http.Request { r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil) p, _ := url.Parse(dest) r.Header.Add("X-Forwarded-Method", method) @@ -450,6 +460,11 @@ func newHttpRequest(method, dest, uri string) *http.Request { return r } +// newDefaultHTTPRequest creates a mocked request from Traefik for http://example.com with no HTTP method +func newDefaultHTTPRequest(uri string) *http.Request { + return newHTTPRequest("", "http://example.com/", uri) +} + func qsDiff(t *testing.T, one, two url.Values) []string { errs := make([]string, 0) for k := range one { From 976876ff03ee03c3f49959e38aa1b7e99d9614b2 Mon Sep 17 00:00:00 2001 From: Mario Hros Date: Sun, 10 May 2020 22:42:32 +0200 Subject: [PATCH 5/9] get rid of global config variable cherry-pick b0dd328 --- cmd/main.go | 19 ++++++++++--------- internal/authentication/auth.go | 13 ++++--------- internal/configuration/config.go | 21 +++++++++++---------- internal/handlers/server.go | 1 - 4 files changed, 25 insertions(+), 29 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index ac656bc..ac86204 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,27 +1,28 @@ package main import ( + "fmt" + "net/http" + "os" + "time" + + "github.com/gorilla/sessions" + k8s "k8s.io/client-go/kubernetes" + "github.com/mesosphere/traefik-forward-auth/internal/api/storage/v1alpha1" "github.com/mesosphere/traefik-forward-auth/internal/authentication" "github.com/mesosphere/traefik-forward-auth/internal/configuration" "github.com/mesosphere/traefik-forward-auth/internal/handlers" kubernetes "github.com/mesosphere/traefik-forward-auth/internal/kubernetes" + logger "github.com/mesosphere/traefik-forward-auth/internal/log" "github.com/mesosphere/traefik-forward-auth/internal/storage" "github.com/mesosphere/traefik-forward-auth/internal/storage/cluster" - "net/http" - "os" - "time" - "fmt" - - "github.com/gorilla/sessions" - logger "github.com/mesosphere/traefik-forward-auth/internal/log" - k8s "k8s.io/client-go/kubernetes" ) // Main func main() { // Parse options - config, err := configuration.NewGlobalConfig(os.Args[1:]) + config, err := configuration.NewConfig(os.Args[1:]) if err != nil { fmt.Printf("%+v\n", err) os.Exit(1) diff --git a/internal/authentication/auth.go b/internal/authentication/auth.go index a9de465..a01c62f 100644 --- a/internal/authentication/auth.go +++ b/internal/authentication/auth.go @@ -19,7 +19,7 @@ type Authenticator struct { } func NewAuthenticator(config *configuration.Config) *Authenticator { - cookieMaxAge := int(config.Lifetime / time.Second) + cookieMaxAge := config.CookieMaxAge() hashKey := []byte(config.SecretString) blockKey := []byte(config.EncryptionKeyString) @@ -102,7 +102,7 @@ func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) { // MakeIDCookie creates an auth cookie func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string) *http.Cookie { - expires := a.cookieExpiry() + expires := a.config.CookieExpiry() data := &ID{ Email: email, Token: token, @@ -126,7 +126,7 @@ func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string // MakeNameCookie creates a name cookie func (a *Authenticator) MakeNameCookie(r *http.Request, name string) *http.Cookie { - expires := a.cookieExpiry() + expires := a.config.CookieExpiry() return &http.Cookie{ Name: a.config.UserCookieName, @@ -148,7 +148,7 @@ func (a *Authenticator) MakeCSRFCookie(r *http.Request, nonce string) *http.Cook Domain: a.csrfCookieDomain(r), HttpOnly: true, Secure: !a.config.InsecureCookie, - Expires: a.cookieExpiry(), + Expires: a.config.CookieExpiry(), } } @@ -240,11 +240,6 @@ func (a *Authenticator) matchCookieDomains(domain string) (bool, string) { return false, p[0] } -// Get cookie expirary -func (a *Authenticator) cookieExpiry() time.Time { - return time.Now().Local().Add(a.config.Lifetime) -} - // Utility methods // getRequestSchemeHost returns scheme://host part of the request diff --git a/internal/configuration/config.go b/internal/configuration/config.go index 88978fd..e9aa0e1 100644 --- a/internal/configuration/config.go +++ b/internal/configuration/config.go @@ -24,9 +24,7 @@ import ( ) var ( - // TODO(jr): Get rid of the global config object - config *Config - log logrus.FieldLogger + log logrus.FieldLogger ) // Config holds app configuration @@ -79,13 +77,6 @@ type Config struct { ServiceAccountToken string } -func NewGlobalConfig(args []string) (*Config, error) { - var err error - config, err = NewConfig(args) - - return config, err -} - // NewConfig loads config from provided args or uses os.Args if nil func NewConfig(args []string) (*Config, error) { if args == nil && len(os.Args) > 0 { @@ -260,6 +251,16 @@ func (c *Config) LoadOIDCProviderConfiguration() error { return nil } +// CookieExpiry returns the cookie expiration time (Now() + configured Lifetime) +func (c Config) CookieExpiry() time.Time { + return time.Now().Local().Add(c.Lifetime) +} + +// CookieMaxAge returns number of seconds to cookie expiration (configured Lifetime converted to seconds) +func (c Config) CookieMaxAge() int { + return int(c.Lifetime / time.Second) +} + func (c Config) String() string { jsonConf, _ := json.Marshal(c) return string(jsonConf) diff --git a/internal/handlers/server.go b/internal/handlers/server.go index 3f0be7a..b7ede7d 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -369,7 +369,6 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { logger.Errorf("failed to get groups claim from the ID token (GroupsAttributeName: %s)", s.config.GroupsAttributeName) } - logger.Printf("creating claims session with groups: %v", groups) if err := s.userinfo.Save(r, w, &v1alpha1.UserInfo{ Username: name.(string), Email: email.(string), From 0b2f10c6c19513ef203dd63117a2328bf23beba3 Mon Sep 17 00:00:00 2001 From: Mario Hros Date: Thu, 14 May 2020 16:06:59 +0200 Subject: [PATCH 6/9] README.md updated for 3.0.0 cherry-pick 78e279f --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 77cf0e4..be389b4 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,10 @@ This is a partial rewrite to support generic OIDC Providers that provide [OpenID [`noelcatt/traefik-forward-auth`](https://github.com/noelcatt/traefik-forward-auth) and [`funkypenguin/traefik-forward-auth`](https://github.com/funkypenguin/traefik-forward-auth) also made [`thomseddon/traefik-forward-auth`](https://github.com/thomseddon/traefik-forward-auth) apply to generic OIDC, but they are now based on an older version which does not support rules and also require the UserInfo endpoint to be supported. +This version optionally implements RBAC within Kuberbetes by using `ClusterRole` and `ClusterRoleBinding`. It extends from the original Kubernetes usage as it also allows specifying full URLs (including a scheme and domain) within `nonResourceURLs` attribute of `ClusterRole`. And unlike the original behavior, `*` wildcard character matches within one path component only. There is a special globstar `**` to match within multiple paths (inspired by Bash, Python and JS libraries). + +The raw id-token received from OIDC provider can optionally be passed upstream via a custom header. + ## Differences to the original The instructions for [`thomseddon/traefik-forward-auth`](https://github.com/thomseddon/traefik-forward-auth) are useful, keeping in mind that this version: @@ -19,3 +23,10 @@ The instructions for [`thomseddon/traefik-forward-auth`](https://github.com/thom - Returns 401 rather than redirect to OIDC Login if an unauthenticated request is not for HTML (e.g. AJAX calls, images). - Sends a username cookie as well - If `auth-host` is set and `cookie-domains` is not set, traefik-forward-auth will redirect any requests using other hostnames to `auth-host`. Set `auth-host` to the OIDC redirect host to ensure that use of the IP or other DNS names will be redirected and get a suitable cookie. + +## Upgrading from 2.x version to 3.0 (Breaking Changes): + +- config `session-key` (`SESSION_KEY` env) is now called `encryption-key` (`ENCRYPTION_KEY` env) and is `REQUIRED` +- config `groups-session-name` (`GROUPS_SESSION_NAME`) is deprecated as both email and groups are part of the single cookie `cookie-name` (`COOKIE_NAME` env) +- character `*` in existing RBAC rules now works within one path component only, so a single `*` has to be replaced with `**` to match the previous behavior (whether to use `*` or `**` is up to the person writing those rules) + From 08f8b4d9d125823ace15392993090d9877187b1f Mon Sep 17 00:00:00 2001 From: Jonathan Giddy Date: Fri, 13 Aug 2021 09:17:13 +0100 Subject: [PATCH 7/9] Use typical error patterns `makeSessionCookie` turns its err value into a nil, losing its info. Also make fatal problems use Error and non-fatal problems use Warn. cherry-pick 3e4e6f9 --- internal/authentication/auth.go | 8 +++++--- internal/authentication/auth_test.go | 14 +++++++++----- internal/handlers/server.go | 26 ++++++++++++++++---------- internal/handlers/server_test.go | 12 ++++++++---- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/internal/authentication/auth.go b/internal/authentication/auth.go index a01c62f..702b8e2 100644 --- a/internal/authentication/auth.go +++ b/internal/authentication/auth.go @@ -101,7 +101,7 @@ func (a *Authenticator) useAuthDomain(r *http.Request) (bool, string) { // Cookie methods // MakeIDCookie creates an auth cookie -func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string) *http.Cookie { +func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string) (*http.Cookie, error) { expires := a.config.CookieExpiry() data := &ID{ Email: email, @@ -110,10 +110,10 @@ func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string encoded, err := a.secureCookie.Encode(a.config.CookieName, data) if err != nil { - return nil + return nil, err } - return &http.Cookie{ + cookie := &http.Cookie{ Name: a.config.CookieName, Value: encoded, Path: "/", @@ -122,6 +122,8 @@ func (a *Authenticator) MakeIDCookie(r *http.Request, email string, token string Secure: !a.config.InsecureCookie, Expires: expires, } + + return cookie, nil } // MakeNameCookie creates a name cookie diff --git a/internal/authentication/auth_test.go b/internal/authentication/auth_test.go index b1cb63e..92d278b 100644 --- a/internal/authentication/auth_test.go +++ b/internal/authentication/auth_test.go @@ -52,7 +52,8 @@ func TestAuthValidateCookie(t *testing.T) { // Should catch expired config.Lifetime = time.Second * time.Duration(-1) a = NewAuthenticator(config) - c = a.MakeIDCookie(r, "test@test.com", "") + c, err = a.MakeIDCookie(r, "test@test.com", "") + assert.Nil(err) _, err = a.ValidateCookie(r, c) if assert.Error(err) { assert.Equal("securecookie: expired timestamp", err.Error()) @@ -61,7 +62,8 @@ func TestAuthValidateCookie(t *testing.T) { // Should accept valid cookie config.Lifetime = time.Second * time.Duration(10) a = NewAuthenticator(config) - c = a.MakeIDCookie(r, "test@test.com", "") + c, err = a.MakeIDCookie(r, "test@test.com", "") + assert.Nil(err) id, err := a.ValidateCookie(r, c) assert.Nil(err, "valid request should not return an error") assert.Equal("test@test.com", id.Email, "valid request should return user email") @@ -124,10 +126,11 @@ func TestAuthMakeCookie(t *testing.T) { r, _ := http.NewRequest("GET", "http://app.example.com", nil) r.Header.Add("X-Forwarded-Host", "app.example.com") - c := a.MakeIDCookie(r, "test@example.com", "") + c, err := a.MakeIDCookie(r, "test@example.com", "") + assert.Nil(err) assert.Equal("_forward_auth", c.Name) assert.Greater(len(c.Value), 18, "encoded securecookie should be longer") - _, err := a.ValidateCookie(r, c) + _, err = a.ValidateCookie(r, c) assert.Nil(err, "should generate valid cookie") assert.Equal("/", c.Path) assert.Equal("app.example.com", c.Domain) @@ -138,7 +141,8 @@ func TestAuthMakeCookie(t *testing.T) { config.CookieName = "testname" config.InsecureCookie = true - c = a.MakeIDCookie(r, "test@example.com", "") + c, err = a.MakeIDCookie(r, "test@example.com", "") + assert.Nil(err) assert.Equal("testname", c.Name) assert.False(c.Secure) } diff --git a/internal/handlers/server.go b/internal/handlers/server.go index b7ede7d..d3f4cc1 100644 --- a/internal/handlers/server.go +++ b/internal/handlers/server.go @@ -264,7 +264,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Check for CSRF cookie c, err := r.Cookie(s.config.CSRFCookieName) if err != nil { - logger.Warnf("missing CSRF cookie: %v", err) + logger.Errorf("missing CSRF cookie: %v", err) http.Error(w, "Not authorized", 401) return } @@ -272,7 +272,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Validate state valid, redirect, err := authentication.ValidateCSRFCookie(r, c) if !valid { - logger.Warnf("error validating CSRF cookie: %v", err) + logger.Errorf("error validating CSRF cookie: %v", err) http.Error(w, "Not authorized", 401) return } @@ -301,7 +301,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Exchange code for token oauth2Token, err := oauth2Config.Exchange(s.config.OIDCContext, r.URL.Query().Get("code")) if err != nil { - logger.Warnf("failed to exchange token: %v", err) + logger.Errorf("failed to exchange token: %v", err) http.Error(w, "Bad Gateway", 502) return } @@ -309,7 +309,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Extract the ID Token from OAuth2 token. rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { - logger.Warnf("missing ID token: %v", err) + logger.Error("missing ID token") http.Error(w, "Bad Gateway", 502) return } @@ -318,7 +318,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { verifier := provider.Verifier(&oidc.Config{ClientID: s.config.ClientID}) idToken, err := verifier.Verify(s.config.OIDCContext, rawIDToken) if err != nil { - logger.Warnf("failed to verify token: %v", err) + logger.Errorf("failed to verify token: %v", err) http.Error(w, "Bad Gateway", 502) return } @@ -326,7 +326,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Extract custom claims var claims map[string]interface{} if err := idToken.Claims(&claims); err != nil { - logger.Warnf("failed to extract claims: %v", err) + logger.Errorf("failed to extract claims: %v", err) http.Error(w, "Bad Gateway", 502) return } @@ -339,12 +339,18 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Generate cookies - http.SetCookie(w, s.authenticator.MakeIDCookie(r, email.(string), token)) + c, err := s.authenticator.MakeIDCookie(r, email.(string), token) + if err != nil { + logger.Errorf("error generating secure session cookie: %v", err) + http.Error(w, "Bad Gateway", 502) + return + } + http.SetCookie(w, c) logger.WithFields(logrus.Fields{ "user": claims["email"].(string), }).Infof("generated auth cookie") } else { - logger.Errorf("no email claim present in the ID token") + logger.Warn("no email claim present in the ID token") } // If name in null, empty or whitespace, use email address for name @@ -356,7 +362,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { http.SetCookie(w, s.authenticator.MakeNameCookie(r, name.(string))) logger.WithFields(logrus.Fields{ "name": name.(string), - }).Infof("generated name cookie") + }).Info("generated name cookie") // Mapping groups groups := []string{} @@ -366,7 +372,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { groups = append(groups, g.(string)) } } else { - logger.Errorf("failed to get groups claim from the ID token (GroupsAttributeName: %s)", s.config.GroupsAttributeName) + logger.Warnf("failed to get groups claim from the ID token (GroupsAttributeName: %s)", s.config.GroupsAttributeName) } if err := s.userinfo.Save(r, w, &v1alpha1.UserInfo{ diff --git a/internal/handlers/server_test.go b/internal/handlers/server_test.go index 487fd01..11c3861 100644 --- a/internal/handlers/server_test.go +++ b/internal/handlers/server_test.go @@ -105,7 +105,8 @@ func TestServerAuthHandlerInvalid(t *testing.T) { req = newDefaultHTTPRequest("/foo") // NOTE(jkoelker) `notAuthenticated` will redirect if it thinks the request is from a browser req.Header.Set("Accept", "application/json") - c := a.MakeIDCookie(req, "test@example.com", "") + c, err := a.MakeIDCookie(req, "test@example.com", "") + assert.Nil(err) config = newTestConfig(testAuthKey2, testEncKey2) // new auth & encryption key! config.AuthHost = "" @@ -115,7 +116,8 @@ func TestServerAuthHandlerInvalid(t *testing.T) { // Should validate email req = newDefaultHTTPRequest("/foo") a = authentication.NewAuthenticator(config) - c = a.MakeIDCookie(req, "test@example.com", "") + c, err = a.MakeIDCookie(req, "test@example.com", "") + assert.Nil(err) config.Domains = []string{"test.com"} res, _ = doHttpRequest(req, c, config) @@ -132,7 +134,8 @@ func TestServerAuthHandlerExpired(t *testing.T) { // Should redirect expired cookie req := newDefaultHTTPRequest("/foo") - c := a.MakeIDCookie(req, "test@example.com", "") + c, err := a.MakeIDCookie(req, "test@example.com", "") + assert.Nil(err) res, _ := doHttpRequest(req, c, config) assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected") @@ -151,7 +154,8 @@ func TestServerAuthHandlerValid(t *testing.T) { // Should allow valid request email req := newDefaultHTTPRequest("/foo") - c := a.MakeIDCookie(req, "test@example.com", "") + c, err := a.MakeIDCookie(req, "test@example.com", "") + assert.Nil(err) config.Domains = []string{} From 20216299fe169a385bf8ac12f0831cafa03155aa Mon Sep 17 00:00:00 2001 From: Jared Rodriguez Date: Mon, 25 Oct 2021 16:58:41 -0500 Subject: [PATCH 8/9] gate v3 url pattern matching in feature flag cherry-pick 3996bce --- internal/authorization/rbac/rbac_test.go | 5 ++++ internal/authorization/urlpatterns.go | 12 ++++++---- internal/authorization/urlpatterns_test.go | 27 +++++++++++++++++++++- internal/configuration/config.go | 7 ++++++ internal/features/vars.go | 13 +++++++++++ 5 files changed, 58 insertions(+), 6 deletions(-) create mode 100644 internal/features/vars.go diff --git a/internal/authorization/rbac/rbac_test.go b/internal/authorization/rbac/rbac_test.go index c384b84..ca8c9af 100644 --- a/internal/authorization/rbac/rbac_test.go +++ b/internal/authorization/rbac/rbac_test.go @@ -1,6 +1,7 @@ package rbac import ( + "github.com/mesosphere/traefik-forward-auth/internal/features" "net/url" "testing" @@ -319,3 +320,7 @@ func TestRBACAuthorizer_AuthorizePatternTypes(t *testing.T) { } } } + +func init() { + features.EnableV3URLPatternMatchin() +} diff --git a/internal/authorization/urlpatterns.go b/internal/authorization/urlpatterns.go index 0eb81ea..c2795d1 100644 --- a/internal/authorization/urlpatterns.go +++ b/internal/authorization/urlpatterns.go @@ -5,6 +5,8 @@ import ( "regexp" "strings" "sync" + + "github.com/mesosphere/traefik-forward-auth/internal/features" ) var ( @@ -118,9 +120,9 @@ func URLMatchesRegexp(url, regex string) bool { // URLMatchesWildcardPattern returns true if the URL matches the pattern containing optional wildcard '*' characters func URLMatchesWildcardPattern(url, pattern string) bool { - // original implementation: - // return pattern == url || - // (strings.HasSuffix(pattern, "*") && strings.HasPrefix(url, strings.TrimRight(pattern, "*"))) - - return globalRECache.MatchString(url, pattern, true) + if features.V3URLPatternMatchingEnabled() { + return globalRECache.MatchString(url, pattern, true) + } else { + return pattern == url || (strings.HasSuffix(pattern, "*") && strings.HasPrefix(url, strings.TrimRight(pattern, "*"))) + } } diff --git a/internal/authorization/urlpatterns_test.go b/internal/authorization/urlpatterns_test.go index 6cc40bb..dc89819 100644 --- a/internal/authorization/urlpatterns_test.go +++ b/internal/authorization/urlpatterns_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/mesosphere/traefik-forward-auth/internal/features" ) func TestWildcardMatches(t *testing.T) { @@ -28,7 +30,7 @@ func TestWildcardMatches(t *testing.T) { {pattern: "/ops/portal/grafana/**", url: "/ops/portal/grafana/public/img/fav32.png", matches: true}, {pattern: "/ops/portal/grafana/**", url: "/ops/portal/grafana/public/build/runtime.3932bda029d2299a9d96.js", matches: true}, } - + features.EnableV3URLPatternMatchin() for _, c := range testCases { if !assert.Equal(t, c.matches, URLMatchesWildcardPattern(c.url, c.pattern)) { t.Logf("URLMatchesWildcardPattern(%v, %v) != %v", c.url, c.pattern, c.matches) @@ -74,6 +76,29 @@ func TestRegexpMatches(t *testing.T) { {pattern: `^https?://[^/]+/`, url: "https://www.google.com/", matches: true}, } + features.EnableV3URLPatternMatchin() + for _, c := range testCases { + if !assert.Equal(t, c.matches, URLMatchesRegexp(c.url, c.pattern)) { + t.Logf("URLMatchesRegexp(%v, %v) != %v", c.url, c.pattern, c.matches) + } + } +} + +func TestOldPreV3Matching(t *testing.T) { + type test struct { + pattern string + url string + matches bool + } + + var testCases = []test{ + {pattern: ``, url: "", matches: false}, + {pattern: ``, url: "/", matches: false}, + {pattern: `/`, url: "", matches: false}, + {pattern: `/`, url: "/", matches: true}, + {pattern: `/admin/*`, url: "/admin/sub1/sub2/index.html", matches: true}, + {pattern: `/admin'`, url: "/admin/sub1/sub2/index.html", matches: false}, + } for _, c := range testCases { if !assert.Equal(t, c.matches, URLMatchesRegexp(c.url, c.pattern)) { t.Logf("URLMatchesRegexp(%v, %v) != %v", c.url, c.pattern, c.matches) diff --git a/internal/configuration/config.go b/internal/configuration/config.go index e9aa0e1..d4332bb 100644 --- a/internal/configuration/config.go +++ b/internal/configuration/config.go @@ -19,6 +19,7 @@ import ( "github.com/sirupsen/logrus" "github.com/thomseddon/go-flags" + "github.com/mesosphere/traefik-forward-auth/internal/features" internallog "github.com/mesosphere/traefik-forward-auth/internal/log" "github.com/mesosphere/traefik-forward-auth/internal/util" ) @@ -75,6 +76,9 @@ type Config struct { OIDCProvider *oidc.Provider Lifetime time.Duration ServiceAccountToken string + + // Flags + EnableV3URLPatternMatching bool `long:"enable-v3-url-pattern-matching" env:"ENABLE_V3_URL_PATTERN_MATCHING" description:"Specifies weather to use v3 URL pattern matching as implemented in this commit: https://github.com/mesosphere/traefik-forward-auth/commit/36c3eee4c9fa262064848d4ddaca6652b96763b5"` } // NewConfig loads config from provided args or uses os.Args if nil @@ -237,6 +241,9 @@ func (c *Config) Validate() { } c.ServiceAccountToken = strings.TrimSuffix(string(t), "\n") } + if c.EnableV3URLPatternMatching { + features.EnableV3URLPatternMatchin() + } } // LoadOIDCProviderConfiguration loads the configuration of OpenID Connect provider diff --git a/internal/features/vars.go b/internal/features/vars.go new file mode 100644 index 0000000..f752ca5 --- /dev/null +++ b/internal/features/vars.go @@ -0,0 +1,13 @@ +package features + +var ( + v3URLPatternMatching bool +) + +func EnableV3URLPatternMatchin() { + v3URLPatternMatching = true +} + +func V3URLPatternMatchingEnabled() bool { + return v3URLPatternMatching +} From 6a275ec4e135640f2822d67be9096e4b3dc689c7 Mon Sep 17 00:00:00 2001 From: Jared Rodriguez Date: Mon, 25 Oct 2021 17:02:22 -0500 Subject: [PATCH 9/9] fix import order (cherry picked from commit 8334159bbf3b4a865fa24bc44dc94250dc8af8e2) --- internal/authorization/rbac/rbac_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/authorization/rbac/rbac_test.go b/internal/authorization/rbac/rbac_test.go index ca8c9af..131d272 100644 --- a/internal/authorization/rbac/rbac_test.go +++ b/internal/authorization/rbac/rbac_test.go @@ -1,7 +1,6 @@ package rbac import ( - "github.com/mesosphere/traefik-forward-auth/internal/features" "net/url" "testing" @@ -13,6 +12,7 @@ import ( "k8s.io/client-go/kubernetes/fake" "github.com/mesosphere/traefik-forward-auth/internal/authorization" + "github.com/mesosphere/traefik-forward-auth/internal/features" ) const (