Skip to content

Commit

Permalink
feat: allow configuration of custom cluster cookie name
Browse files Browse the repository at this point in the history
Fixes #145
  • Loading branch information
mcollovati committed Nov 15, 2024
1 parent 31c2a13 commit a50a665
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ protected FilterRegistration.Dynamic addRegistration(
}

@Bean
PushSessionTracker pushSendListener(
SessionSerializer sessionSerializer) {
return new PushSessionTracker(sessionSerializer);
PushSessionTracker pushSendListener(SessionSerializer sessionSerializer,
KubernetesKitProperties properties) {
return new PushSessionTracker(sessionSerializer,
properties.getClusterKeyCookieName());
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.springframework.boot.context.properties.ConfigurationProperties;

import com.vaadin.kubernetes.starter.sessiontracker.CurrentKey;
import com.vaadin.kubernetes.starter.sessiontracker.SameSite;

/**
Expand All @@ -32,6 +33,11 @@ public class KubernetesKitProperties {
*/
private boolean autoConfigure = true;

/**
* The name of the distributed storage session key cookie.
*/
private String clusterKeyCookieName = CurrentKey.COOKIE_NAME;

/**
* Value of the distributed storage session key cookie's SameSite attribute.
*/
Expand Down Expand Up @@ -64,6 +70,25 @@ public void setAutoConfigure(boolean autoConfigure) {
this.autoConfigure = autoConfigure;
}

/**
* Gets the name of the distributed storage session key cookie.
*
* @return the name of the distributed storage session key cookie
*/
public String getClusterKeyCookieName() {
return clusterKeyCookieName;
}

/**
* Sets the name of the distributed storage session key cookie.
*
* @param clusterKeyCookieName
* the name of the distributed storage session key cookie
*/
public void setClusterKeyCookieName(String clusterKeyCookieName) {
this.clusterKeyCookieName = clusterKeyCookieName;
}

/**
* Gets the distributed storage session key cookie's SameSite attribute
* value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;

import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;
Expand All @@ -39,17 +40,49 @@ private SessionTrackerCookie() {
* the HTTP request.
* @param response
* the HTTP response.
* @param cookieConsumer
* function to apply custom setting to the cluster key cookie.
* @deprecated use
* {@link #setIfNeeded(HttpSession, HttpServletRequest, HttpServletResponse, String, Consumer)}
* providing the cluster cookie name instead.
*/
@Deprecated(since = "2.4", forRemoval = true)
public static void setIfNeeded(HttpSession session,
HttpServletRequest request, HttpServletResponse response,
Consumer<Cookie> cookieConsumer) {
Optional<Cookie> clusterKeyCookie = getCookie(request);
setIfNeeded(session, request, response, CurrentKey.COOKIE_NAME,
cookieConsumer);
}

/**
* Sets the distributed storage session key on the HTTP session.
*
* If the Cookie does not exist, a new key is generated and the Cookie is
* created and added to the HTTP response.
*
* @param session
* the HTTP session.
* @param request
* the HTTP request.
* @param response
* the HTTP response.
* @param cookieName
* the name for the cluster cookie.
* @param cookieConsumer
* function to apply custom setting to the cluster key cookie.
*/
public static void setIfNeeded(HttpSession session,
HttpServletRequest request, HttpServletResponse response,
String cookieName, Consumer<Cookie> cookieConsumer) {
cookieName = Objects.requireNonNullElse(cookieName,
CurrentKey.COOKIE_NAME);
Optional<Cookie> clusterKeyCookie = getCookie(request, cookieName);
if (clusterKeyCookie.isEmpty()) {
String clusterKey = UUID.randomUUID().toString();
if (session != null) {
session.setAttribute(CurrentKey.COOKIE_NAME, clusterKey);
}
Cookie cookie = new Cookie(CurrentKey.COOKIE_NAME, clusterKey);
Cookie cookie = new Cookie(cookieName, clusterKey);
cookieConsumer.accept(cookie);
response.addCookie(cookie);
} else if (session != null
Expand All @@ -72,13 +105,13 @@ public static Optional<String> getFromSession(HttpSession session) {
(String) session.getAttribute(CurrentKey.COOKIE_NAME));
}

private static Optional<Cookie> getCookie(HttpServletRequest request) {
private static Optional<Cookie> getCookie(HttpServletRequest request,
String cookieName) {
Cookie[] cookies = request.getCookies();
if (cookies == null) {
return Optional.empty();
}
return Stream.of(cookies)
.filter(c -> c.getName().equals(CurrentKey.COOKIE_NAME))
return Stream.of(cookies).filter(c -> c.getName().equals(cookieName))
.findFirst();
}

Expand All @@ -91,9 +124,29 @@ private static Optional<Cookie> getCookie(HttpServletRequest request) {
* @return the current distributed storage session key wrapped into an
* {@link Optional}, or an empty Optional if the Cookie does not
* exist.
* @deprecated use {@link #getValue(HttpServletRequest, String)} providing
* the cluster cookie name instead.
*/
@Deprecated(since = "2.4", forRemoval = true)
public static Optional<String> getValue(HttpServletRequest request) {
return getCookie(request).map(Cookie::getValue);
return getValue(request, CurrentKey.COOKIE_NAME);
}

/**
* Gets the value of the current distributed storage session key from the
* Cookie.
*
* @param request
* the HTTP request.
* @param cookieName
* the name of the cluster key cookie.
* @return the current distributed storage session key wrapped into an
* {@link Optional}, or an empty Optional if the Cookie does not
* exist.
*/
public static Optional<String> getValue(HttpServletRequest request,
String cookieName) {
return getCookie(request, cookieName).map(Cookie::getValue);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public SessionTrackerFilter(SessionSerializer sessionSerializer,
protected void doFilter(HttpServletRequest request,
HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
SessionTrackerCookie.getValue(request).ifPresent(key -> {
String cookieName = properties.getClusterKeyCookieName();
SessionTrackerCookie.getValue(request, cookieName).ifPresent(key -> {
CurrentKey.set(key);
if (request.getSession(false) == null) {
// Cluster key set but no session, create one, so it can be
Expand All @@ -74,7 +75,7 @@ protected void doFilter(HttpServletRequest request,
HttpSession session = request.getSession(false);

SessionTrackerCookie.setIfNeeded(session, request, response,
cookieConsumer(request));
cookieName, cookieConsumer(request));
super.doFilter(request, response, chain);

if (session != null && request.isRequestedSessionIdValid()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,22 @@ public class PushSessionTracker implements PushSendListener {
private final SessionSerializer sessionSerializer;

private Predicate<String> activeSessionChecker = id -> true;
private String clusterCookieName;

/**
* @deprecated use {@link #PushSessionTracker(SessionSerializer, String)}
* instead
*/
@Deprecated(forRemoval = true)
public PushSessionTracker(SessionSerializer sessionSerializer) {
this.sessionSerializer = sessionSerializer;
this.clusterCookieName = CurrentKey.COOKIE_NAME;
}

public PushSessionTracker(SessionSerializer sessionSerializer,
String clusterCookieName) {
this.sessionSerializer = sessionSerializer;
this.clusterCookieName = clusterCookieName;
}

/**
Expand Down Expand Up @@ -106,7 +119,8 @@ private Optional<String> tryGetSerializationKey(
if (key == null) {
try {
key = SessionTrackerCookie
.getValue(resource.getRequest().wrappedRequest())
.getValue(resource.getRequest().wrappedRequest(),
clusterCookieName)
.orElse(null);
} catch (Exception ex) {
getLogger().debug("Cannot get serialization key from request",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package com.vaadin.kubernetes.starter.sessiontracker;

import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;

import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;

import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand All @@ -22,6 +23,8 @@

public class SessionTrackerCookieTest {

public static final String CLUSTER_COOKIE_NAME = "MY_CLUSTER_COOKIE";

@Test
void setIfNeeded_nullCookies_attributeIsSetAndCookieIsConfigured() {
HttpSession session = mock(HttpSession.class);
Expand All @@ -33,7 +36,7 @@ void setIfNeeded_nullCookies_attributeIsSetAndCookieIsConfigured() {
Consumer.class);

SessionTrackerCookie.setIfNeeded(session, request, response,
cookieConsumer);
CLUSTER_COOKIE_NAME, cookieConsumer);

verify(session).setAttribute(eq(CurrentKey.COOKIE_NAME), anyString());
verify(cookieConsumer).accept(any());
Expand All @@ -51,7 +54,7 @@ void setIfNeeded_emptyCookies_attributeIsSetAndCookieIsConfigured() {
Consumer.class);

SessionTrackerCookie.setIfNeeded(session, request, response,
cookieConsumer);
CLUSTER_COOKIE_NAME, cookieConsumer);

verify(session).setAttribute(eq(CurrentKey.COOKIE_NAME), anyString());
verify(cookieConsumer).accept(any());
Expand All @@ -65,14 +68,14 @@ void setIfNeeded_nullSessionAttribute_attributeIsSet() {
HttpSession session = mock(HttpSession.class);
when(session.getAttribute(anyString())).thenReturn(null);
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getCookies()).thenReturn(new Cookie[] {
new Cookie(CurrentKey.COOKIE_NAME, clusterKey) });
when(request.getCookies()).thenReturn(
new Cookie[] { new Cookie(CLUSTER_COOKIE_NAME, clusterKey) });
HttpServletResponse response = mock(HttpServletResponse.class);
Consumer<Cookie> cookieConsumer = (Cookie cookie) -> {
};

SessionTrackerCookie.setIfNeeded(session, request, response,
cookieConsumer);
CLUSTER_COOKIE_NAME, cookieConsumer);

verify(session).setAttribute(eq(CurrentKey.COOKIE_NAME),
eq(clusterKey));
Expand All @@ -86,14 +89,14 @@ void setIfNeeded_nonNullSessionAttribute_attributeIsNotSet() {
HttpSession session = mock(HttpSession.class);
when(session.getAttribute(anyString())).thenReturn("foo");
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getCookies()).thenReturn(new Cookie[] {
new Cookie(CurrentKey.COOKIE_NAME, clusterKey) });
when(request.getCookies()).thenReturn(
new Cookie[] { new Cookie(CLUSTER_COOKIE_NAME, clusterKey) });
HttpServletResponse response = mock(HttpServletResponse.class);
Consumer<Cookie> cookieConsumer = (Cookie cookie) -> {
};

SessionTrackerCookie.setIfNeeded(session, request, response,
cookieConsumer);
CLUSTER_COOKIE_NAME, cookieConsumer);

verify(session, never()).setAttribute(any(), any());
verify(response, never()).addCookie(any());
Expand All @@ -113,7 +116,8 @@ void getValue_nullCookies_emptyIsReturned() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getCookies()).thenReturn(null);

Optional<String> value = SessionTrackerCookie.getValue(request);
Optional<String> value = SessionTrackerCookie.getValue(request,
CLUSTER_COOKIE_NAME);

verify(request).getCookies();
assertEquals(Optional.empty(), value);
Expand All @@ -124,7 +128,8 @@ void getValue_emptyCookies_emptyIsReturned() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getCookies()).thenReturn(new Cookie[0]);

Optional<String> value = SessionTrackerCookie.getValue(request);
Optional<String> value = SessionTrackerCookie.getValue(request,
CLUSTER_COOKIE_NAME);

verify(request).getCookies();
assertTrue(value.isEmpty());
Expand All @@ -135,17 +140,17 @@ void getValue_valueIsReturned() {
String clusterKey = UUID.randomUUID().toString();

HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getCookies()).thenReturn(new Cookie[] {
new Cookie(CurrentKey.COOKIE_NAME, clusterKey) });
when(request.getCookies()).thenReturn(
new Cookie[] { new Cookie(CLUSTER_COOKIE_NAME, clusterKey) });

Optional<String> value = SessionTrackerCookie.getValue(request);
Optional<String> value = SessionTrackerCookie.getValue(request,
CLUSTER_COOKIE_NAME);

verify(request).getCookies();
assertTrue(value.isPresent());
assertEquals(clusterKey, value.get());
}


@Test
void setIfNeeded_nullCookiesAndSession_cookieIsConfigured() {
HttpServletRequest request = mock(HttpServletRequest.class);
Expand All @@ -156,11 +161,10 @@ void setIfNeeded_nullCookiesAndSession_cookieIsConfigured() {
Consumer.class);

SessionTrackerCookie.setIfNeeded(null, request, response,
cookieConsumer);
CLUSTER_COOKIE_NAME, cookieConsumer);

verify(cookieConsumer).accept(any());
verify(response).addCookie(any());
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -143,7 +145,8 @@ void validHttpSession_cookieConsumer_configuresCookie() throws Exception {
SessionTrackerCookie.class)) {
filter.doFilter(request, response, filterChain);
mockedStatic.verify(() -> SessionTrackerCookie.setIfNeeded(any(),
any(), any(), cookieConsumerArgumentCaptor.capture()));
any(), any(), anyString(),
cookieConsumerArgumentCaptor.capture()));
Consumer<Cookie> cookieConsumer = cookieConsumerArgumentCaptor
.getValue();
cookieConsumer.accept(cookie);
Expand Down
Loading

0 comments on commit a50a665

Please sign in to comment.