diff --git a/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/DuplicationStrategy.java b/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/DuplicationStrategy.java index 9c592a8..63677a3 100644 --- a/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/DuplicationStrategy.java +++ b/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/DuplicationStrategy.java @@ -30,27 +30,31 @@ public interface DuplicationStrategy { * * @param clientAddress the client address * @param requestPacket the request packet + * @param requestPacketBytes the request packet bytes * * @return the duplication strategy result */ - Result handleRequest(InetSocketAddress clientAddress, Packet requestPacket); + Result handleRequest(InetSocketAddress clientAddress, Packet requestPacket, byte[] requestPacketBytes); /** * Handles a response. The response will be saved to the cache. * * @param clientAddress the client address * @param requestPacket the request packet + * @param requestPacketBytes the request packet bytes * @param responsePacket the response packet */ - void handleResponse(InetSocketAddress clientAddress, Packet requestPacket, Packet responsePacket); + void handleResponse(InetSocketAddress clientAddress, Packet requestPacket, byte[] requestPacketBytes, + Packet responsePacket); /** * Unhandles a request. Removes an in-progress request from the cache. * * @param clientAddress the client address * @param requestPacket the request packet + * @param requestPacketBytes the request packet bytes */ - void unhandleRequest(InetSocketAddress clientAddress, Packet requestPacket); + void unhandleRequest(InetSocketAddress clientAddress, Packet requestPacket, byte[] requestPacketBytes); /** * A duplication strategy result contains a State and a possible response packet. When {@link #getState()} is diff --git a/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/TimedDuplicationStrategy.java b/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/TimedDuplicationStrategy.java index 6f64efa..27b0c64 100644 --- a/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/TimedDuplicationStrategy.java +++ b/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/TimedDuplicationStrategy.java @@ -30,6 +30,7 @@ /** * An implementation of {@link DuplicationStrategy} with a configurable time-to-live value for the cached responses. + * Caches all requests using the entire request packet bytes. */ public final class TimedDuplicationStrategy implements DuplicationStrategy { @@ -43,7 +44,9 @@ public TimedDuplicationStrategy(Duration ttlDuration) { } @Override - public synchronized Result handleRequest(InetSocketAddress clientAddress, Packet requestPacket) { + public synchronized Result handleRequest(InetSocketAddress clientAddress, Packet requestPacket, + byte[] requestPacketBytes) + { long currentEpochMillis = Instant.now().toEpochMilli(); // Remove the expired cache entries @@ -61,7 +64,7 @@ public synchronized Result handleRequest(InetSocketAddress clientAddress, Packet } } - CacheKey cacheKey = new CacheKey(clientAddress, requestPacket.getReceivedFields().getIdentifier()); + CacheKey cacheKey = new CacheKey(clientAddress, requestPacketBytes); if (!cacheMap.containsKey(cacheKey)) { // It's a new, unseen request; add it to the cache @@ -72,13 +75,6 @@ public synchronized Result handleRequest(InetSocketAddress clientAddress, Packet CacheValue cacheValue = cacheMap.get(cacheKey); - if (!Arrays.equals(cacheValue.requestPacket.getReceivedFields().getAuthenticator(), - requestPacket.getReceivedFields().getAuthenticator())) - { - // If the request authenticator field is different from what we have in the cache then we must clear it - cacheMap.replace(cacheKey, new CacheValue(currentEpochMillis, requestPacket)); - } - if (cacheValue.responsePacket != null) { return new Result(State.CACHED_RESPONSE, cacheValue.responsePacket); } @@ -89,9 +85,9 @@ public synchronized Result handleRequest(InetSocketAddress clientAddress, Packet @Override public synchronized void handleResponse(InetSocketAddress clientAddress, Packet requestPacket, - Packet responsePacket) + byte[] requestPacketBytes, Packet responsePacket) { - CacheKey cacheKey = new CacheKey(clientAddress, requestPacket.getReceivedFields().getIdentifier()); + CacheKey cacheKey = new CacheKey(clientAddress, requestPacketBytes); if (!cacheMap.containsKey(cacheKey)) { return; @@ -110,8 +106,10 @@ public synchronized void handleResponse(InetSocketAddress clientAddress, Packet } @Override - public synchronized void unhandleRequest(InetSocketAddress clientAddress, Packet requestPacket) { - CacheKey cacheKey = new CacheKey(clientAddress, requestPacket.getReceivedFields().getIdentifier()); + public synchronized void unhandleRequest(InetSocketAddress clientAddress, Packet requestPacket, + byte[] requestPacketBytes) + { + CacheKey cacheKey = new CacheKey(clientAddress, requestPacketBytes); if (!cacheMap.containsKey(cacheKey)) { return; @@ -133,16 +131,16 @@ private static class CacheKey { private final InetSocketAddress clientAddress; - private final int identifier; + private final byte[] requestBytes; - private CacheKey(InetSocketAddress clientAddress, int identifier) { + private CacheKey(InetSocketAddress clientAddress, byte[] requestBytes) { this.clientAddress = clientAddress; - this.identifier = identifier; + this.requestBytes = requestBytes; } @Override public int hashCode() { - return Objects.hash(clientAddress, identifier); + return Objects.hash(clientAddress, Arrays.hashCode(requestBytes)); } @Override @@ -157,7 +155,8 @@ public boolean equals(Object obj) { CacheKey cacheKey = (CacheKey) obj; - return Objects.equals(clientAddress, cacheKey.clientAddress) && identifier == cacheKey.identifier; + return Objects.equals(clientAddress, cacheKey.clientAddress) + && Arrays.equals(requestBytes, cacheKey.requestBytes); } } diff --git a/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/servers/UdpRadiusServer.java b/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/servers/UdpRadiusServer.java index 4238c93..005c35a 100644 --- a/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/servers/UdpRadiusServer.java +++ b/aaa4j-radius-server/src/main/java/org/aaa4j/radius/server/servers/UdpRadiusServer.java @@ -261,7 +261,8 @@ private void runHandler(InetSocketAddress clientAddress, ByteBuffer inByteBuffer Packet responsePacket = null; // Check the duplication cache for a cached response - Result result = udpRadiusServer.duplicationStrategy.handleRequest(clientAddress, requestPacket); + Result result = + udpRadiusServer.duplicationStrategy.handleRequest(clientAddress, requestPacket, inBytes); switch (result.getState()) { case NEW_REQUEST: @@ -272,11 +273,11 @@ private void runHandler(InetSocketAddress clientAddress, ByteBuffer inByteBuffer if (responsePacket != null) { udpRadiusServer.duplicationStrategy.handleResponse(clientAddress, requestPacket, - responsePacket); + inBytes, responsePacket); } } catch (Exception e) { - udpRadiusServer.duplicationStrategy.unhandleRequest(clientAddress, requestPacket); + udpRadiusServer.duplicationStrategy.unhandleRequest(clientAddress, requestPacket, inBytes); throw e; }