Skip to content

Commit

Permalink
Option to extract client connection user id from http header (#730)
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia authored Oct 29, 2023
1 parent 2194ff4 commit 540cda5
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
25 changes: 25 additions & 0 deletions internal/middleware/user_header_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package middleware

import (
"net/http"

"github.com/centrifugal/centrifuge"
)

// UserHeaderAuth is a middleware that extracts the value of user ID from the specific header
// and sets connection credentials.
func UserHeaderAuth(userHeaderName string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID := r.Header.Get(userHeaderName)
if userID != "" {
ctx := centrifuge.SetCredentials(r.Context(), &centrifuge.Credentials{
UserID: userID,
})
r = r.WithContext(ctx)
}
// Call the next handler
next.ServeHTTP(w, r)
})
}
}
46 changes: 46 additions & 0 deletions internal/middleware/user_header_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/centrifugal/centrifuge"
"github.com/stretchr/testify/require"
)

func userHeaderAuthTestHandler(t *testing.T, userMustBeSet bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cred, ok := centrifuge.GetCredentials(r.Context())
if userMustBeSet {
require.True(t, ok, "credentials should be set")
require.Equal(t, "123", cred.UserID, "user ID should be set correctly")
} else {
require.False(t, ok)
}
_, _ = w.Write([]byte("OK"))
})
}

func TestUserHeaderAuthWithUserID(t *testing.T) {
middleware := UserHeaderAuth("X-User-ID")
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-User-ID", "123")
rr := httptest.NewRecorder()

handler := middleware(userHeaderAuthTestHandler(t, true))
handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code, "status code should be 200 OK")
}

func TestUserHeaderAuthWithoutUserID(t *testing.T) {
middleware := UserHeaderAuth("X-User-ID")
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()

handler := middleware(userHeaderAuthTestHandler(t, false))
handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code, "status code should be 200 OK")
}
6 changes: 6 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ var defaults = map[string]any{
"client_insecure_skip_token_signature_verify": false,
"api_insecure": false,

"client_user_id_http_header": "",

"token_hmac_secret_key": "",
"token_rsa_public_key": "",
"token_ecdsa_public_key": "",
Expand Down Expand Up @@ -2806,6 +2808,10 @@ func Mux(n *centrifuge.Node, ruleContainer *rule.Container, apiExecutor *api.Exe
connLimitMW := middleware.NewConnLimit(n, ruleContainer)
connMiddlewares = append(connMiddlewares, connLimitMW.Middleware)
}
userIDHTTPHeader := v.GetString("client_user_id_http_header")
if userIDHTTPHeader != "" {
connMiddlewares = append(connMiddlewares, middleware.UserHeaderAuth(userIDHTTPHeader))
}
if keepHeadersInContext {
connMiddlewares = append(connMiddlewares, middleware.HeadersToContext)
}
Expand Down

0 comments on commit 540cda5

Please sign in to comment.