Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rate limit per service #3903

Draft
wants to merge 23 commits into
base: v3.x.x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
ZWE_CONFIGS_APIML_SERVICE_ADDITIONALREGISTRATION_0_DISCOVERYSERVICEURLS: https://discovery-service-2:10011/eureka
SERVER_MAX_HTTP_REQUEST_HEADER_SIZE: 16348
SERVER_WEBSOCKET_REQUESTBUFFERSIZE: 16348
APIML_GATEWAY_ROUTING_SERVICESTOLIMITREQUESTRATE: discoverableclient
APIML_GATEWAY_SERVICESTOLIMITREQUESTRATE: discoverableclient
APIML_GATEWAY_ROUTING_COOKIENAMEFORRATELIMIT: apimlAuthenticationToken
zaas-service:
image: ghcr.io/balhar-jakub/zaas-service:${{ github.run_id }}-${{ github.run_number }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ private EurekaMetadataDefinition() {
public static final String SERVICE_EXTERNAL_URL = "apiml.service.externalUrl";
public static final String SERVICE_SUPPORTING_CLIENT_CERT_FORWARDING = "apiml.service.supportClientCertForwarding";
public static final String ENABLE_URL_ENCODED_CHARACTERS = "apiml.enableUrlEncodedCharacters";
public static final String APPLY_RATE_LIMITER_FILTER = "gateway.applyRateLimiterFilter";
public static final String APIML_ID = "apiml.service.apimlId";
public static final String REGISTRATION_TYPE = "apiml.registrationType";

Expand Down
8 changes: 4 additions & 4 deletions gateway-package/src/main/resources/bin/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,10 @@ _BPX_JOBNAME=${ZWE_zowe_job_prefix}${GATEWAY_CODE} ${JAVA_BIN_DIR}java \
-Dapiml.gateway.cachePeriodSec=${ZWE_configs_apiml_gateway_registry_cachePeriodSec:-120} \
-Dapiml.gateway.registry.enabled=${ZWE_configs_apiml_gateway_registry_enabled:-false} \
-Dapiml.gateway.maxSimultaneousRequests=${ZWE_configs_gateway_registry_maxSimultaneousRequests:-20} \
-Dapiml.gateway.rateLimiterCapacity=${ZWE_configs_apiml_gateway_routing_rateLimiterCapacity:-20} \
-Dapiml.gateway.rateLimiterTokens=${ZWE_configs_apiml_gateway_routing_rateLimiterTokens:-20} \
-Dapiml.gateway.rateLimiterRefillDuration=${ZWE_configs_apiml_gateway_routing_rateLimiterRefillDuration:-1} \
-Dapiml.gateway.servicesToLimitRequestRate=${ZWE_configs_apiml_gateway_routing_servicesToLimitRequestRate:-} \
-Dapiml.gateway.rateLimiterCapacity=${ZWE_configs_apiml_gateway_rateLimiterCapacity:-20} \
-Dapiml.gateway.rateLimiterTokens=${ZWE_configs_apiml_gateway_rateLimiterTokens:-20} \
-Dapiml.gateway.rateLimiterRefillDuration=${ZWE_configs_apiml_gateway_rateLimiterRefillDuration:-1} \
-Dapiml.gateway.servicesToLimitRequestRate=${ZWE_configs_apiml_gateway_servicesToLimitRequestRate:-} \
-Dapiml.gateway.cookieNameForRateLimit=${cookieName:-apimlAuthenticationToken} \
-Dapiml.gateway.registry.metadata-key-allow-list=${ZWE_configs_gateway_registry_metadataKeyAllowList:-} \
-Dapiml.gateway.refresh-interval-ms=${ZWE_configs_gateway_registry_refreshIntervalMs:-30000} \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ public List<FilterDefinition> filters() {
retryFilter.addArg("series", "");
filters.add(retryFilter);

FilterDefinition rateLimiterFilter = new FilterDefinition();
rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory");
filters.add(rateLimiterFilter);

for (String headerName : ignoredHeadersWhenCorsEnabled.split(",")) {
FilterDefinition removeHeaders = new FilterDefinition();
removeHeaders.setName("RemoveRequestHeader");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
public class InMemoryRateLimiter implements RateLimiter<InMemoryRateLimiter.Config> {

private final Map<String, Bucket> cache = new ConcurrentHashMap<>();
@Value("${apiml.gateway.routing.rateLimiterCapacity:20}")
int capacity;
@Value("${apiml.gateway.routing.rateLimiterTokens:20}")
int tokens;
@Value("${apiml.gateway.routing.rateLimiterRefillDuration:1}")

@Value("${apiml.gateway.rateLimiterCapacity:20}")
Integer capacity;

@Value("${apiml.gateway.rateLimiterTokens:20}")
Integer tokens;

@Value("${apiml.gateway.rateLimiterRefillDuration:1}")
Integer refillDuration;

@Override
Expand All @@ -55,6 +58,12 @@ private Map<String, String> getHeaders(Bucket bucket) {
return headers;
}

public void setParameters(Integer capacity, Integer tokens, Integer refillDuration) {
this.capacity = (capacity != null) ? capacity : this.capacity;
this.tokens = (tokens != null) ? tokens : this.tokens;
this.refillDuration = (refillDuration != null) ? refillDuration : this.refillDuration;;
}

@Override
public Map<String, Config> getConfig() {
Config defaultConfig = new Config();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.Setter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.http.HttpStatus;
Expand All @@ -34,13 +33,10 @@ public class InMemoryRateLimiterFilterFactory extends AbstractGatewayFilterFacto
@InjectApimlLogger
private ApimlLogger apimlLog = ApimlLogger.empty();

private final InMemoryRateLimiter rateLimiter;
private InMemoryRateLimiter rateLimiter;

private final KeyResolver keyResolver;

@Value("${apiml.gateway.routing.servicesToLimitRequestRate:-}")
List<String> serviceIds;

private final ObjectMapper mapper;

private final MessageService messageService;
Expand All @@ -55,17 +51,18 @@ public InMemoryRateLimiterFilterFactory(InMemoryRateLimiter rateLimiter, KeyReso

@Override
public GatewayFilter apply(Config config) {
this.rateLimiter.setParameters(config.capacity, config.tokens, config.refillDuration);
return (exchange, chain) -> {
List<PathContainer.Element> pathElements = exchange.getRequest().getPath().elements();
String requestPath = (!pathElements.isEmpty() && pathElements.size() > 1) ? pathElements.get(1).value() : null;
if (requestPath == null || !serviceIds.contains(requestPath)) {
if (requestPath == null) {
return chain.filter(exchange);
}
return keyResolver.resolve(exchange).flatMap(key -> {
if (key.isEmpty()) {
return chain.filter(exchange);
}
return rateLimiter.isAllowed(config.getRouteId(), key).flatMap(response -> {
return rateLimiter.isAllowed(requestPath, key).flatMap(response -> {
if (response.isAllowed()) {
return chain.filter(exchange);
} else {
Expand All @@ -87,9 +84,8 @@ public GatewayFilter apply(Config config) {
@Getter
@Setter
public static class Config {
private String routeId;
private Integer capacity;
private Integer tokens;
private Integer refillIntervalSeconds;
private Integer refillDuration;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public class RouteLocator implements RouteDefinitionLocator {
@Value("${apiml.service.forwardClientCertEnabled:false}")
private boolean forwardingClientCertEnabled;

@Value("${apiml.gateway.servicesToLimitRequestRate:-}")
List<String> servicesToLimitRequestRate;

private final ApplicationContext context;

private final CorsUtils corsUtils;
Expand Down Expand Up @@ -140,6 +143,21 @@ List<FilterDefinition> getPostRoutingFilters(ServiceInstance serviceInstance) {
serviceRelated.add(forbidEncodedCharactersFilter);
}

if (Optional.ofNullable(serviceInstance.getMetadata().get(APPLY_RATE_LIMITER_FILTER))
.map(Boolean::parseBoolean)
.orElse(false)) {
FilterDefinition rateLimiterFilter = new FilterDefinition();
rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory");
rateLimiterFilter.addArg("capacity", serviceInstance.getMetadata().get("gateway.rateLimiterCapacity"));
rateLimiterFilter.addArg("tokens", serviceInstance.getMetadata().get("gateway.rateLimiterTokens"));
rateLimiterFilter.addArg("refillDuration", serviceInstance.getMetadata().get("gateway.refillDuration"));
serviceRelated.add(rateLimiterFilter);
} else if (servicesToLimitRequestRate != null && servicesToLimitRequestRate.contains(serviceInstance.getServiceId().toLowerCase())) {
FilterDefinition rateLimiterFilter = new FilterDefinition();
rateLimiterFilter.setName("InMemoryRateLimiterFilterFactory");
serviceRelated.add(rateLimiterFilter);
}

FilterDefinition pageRedirectionFilter = new FilterDefinition();
pageRedirectionFilter.setName("PageRedirectionFilterFactory");
pageRedirectionFilter.addArg("serviceId", serviceInstance.getServiceId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand All @@ -40,7 +39,6 @@ public class InMemoryRateLimiterFilterFactoryTest {
private InMemoryRateLimiterFilterFactory filterFactory;
private ServerWebExchange exchange;
private GatewayFilterChain chain;
private String serviceId;
private MockServerHttpRequest request;
private InMemoryRateLimiterFilterFactory.Config config;
private MessageService messageService;
Expand All @@ -49,19 +47,16 @@ public class InMemoryRateLimiterFilterFactoryTest {

@BeforeEach
public void setUp() {
serviceId = "testService";
rateLimiter = mock(InMemoryRateLimiter.class);
keyResolver = mock(KeyResolver.class);
messageService = mock(MessageService.class);
message = mock(Message.class);
objectMapper = mock(ObjectMapper.class);
filterFactory = new InMemoryRateLimiterFilterFactory(rateLimiter, keyResolver,objectMapper, messageService);
filterFactory.serviceIds = List.of(serviceId);
request = MockServerHttpRequest.get("/" + serviceId).build();
request = MockServerHttpRequest.get("/" + "serviceId").build();
exchange = MockServerWebExchange.from(request);
chain = mock(GatewayFilterChain.class);
config = mock(InMemoryRateLimiterFilterFactory.Config.class);
when(config.getRouteId()).thenReturn("testRoute");
}

@Test
Expand Down Expand Up @@ -115,20 +110,6 @@ public void apply_shouldAllowRequest_whenKeyIsNull() {
verify(chain, times(1)).filter(exchange);
}

@Test
public void apply_shouldAllowRequest_whenServiceIdDoesNotMatch() {
String nonMatchingId = "nonMatchingId";
when(keyResolver.resolve(exchange)).thenReturn(Mono.just("testKey"));
request = MockServerHttpRequest.get("/" + nonMatchingId).build();
exchange = MockServerWebExchange.from(request);
when(chain.filter(any(ServerWebExchange.class))).thenReturn(Mono.empty());

StepVerifier.create(filterFactory.apply(config).filter(exchange, chain))
.expectComplete()
.verify();
verify(chain, times(1)).filter(exchange);
}

@Test
public void apply_shouldAllowRequest_whenServiceIdEmpty() {
when(keyResolver.resolve(exchange)).thenReturn(Mono.just("testKey"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,24 @@ public void testNewConfig() {
assertEquals(rateLimiter.tokens, config.getTokens(), "Config tokens should match the rate limiter tokens");
assertEquals(rateLimiter.refillDuration, config.getRefillDuration(), "Config refill duration should match the rate limiter refill duration");
}

@Test
public void setNonNullParametersTest() {
Integer newCapacity = 20;
Integer newTokens = 20;
Integer newRefillDuration = 2;
rateLimiter.setParameters(newCapacity, newTokens, newRefillDuration);
assertEquals(newCapacity, rateLimiter.capacity);
assertEquals(newTokens, rateLimiter.tokens);
assertEquals(newRefillDuration, rateLimiter.refillDuration);
}

@Test
public void setParametersWithNullValuesTest() {
Integer newCapacity = 30;
rateLimiter.setParameters(newCapacity, null, null);
assertEquals(newCapacity, rateLimiter.capacity);
assertNotNull(rateLimiter.tokens);
assertNotNull(rateLimiter.refillDuration);
}
}
Loading
Loading