From f9a537be64f1ab3c1bbe1b57e1a5cdb64d07eb16 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 19 Sep 2024 16:33:52 +0200 Subject: [PATCH] Add new rate limit headers --- apps/opik-backend/config.yml | 2 +- .../opik/infrastructure/auth/AuthFilter.java | 3 +- .../infrastructure/auth/RequestContext.java | 5 + .../ratelimit/RateLimitInterceptor.java | 19 +- .../ratelimit/RateLimitResponseFilter.java | 36 ++++ .../ratelimit/RateLimitService.java | 4 + .../redis/RedisRateLimitService.java | 87 +++------ .../ratelimit/RateLimitE2ETest.java | 175 ++++++++++++++---- 8 files changed, 221 insertions(+), 110 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java diff --git a/apps/opik-backend/config.yml b/apps/opik-backend/config.yml index d70b40d0a1..25e11fd1b7 100644 --- a/apps/opik-backend/config.yml +++ b/apps/opik-backend/config.yml @@ -70,4 +70,4 @@ rateLimit: enabled: ${RATE_LIMIT_ENABLED:-false} generalEvents: limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-5000} - durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION:-1} \ No newline at end of file + durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION_SEC:-1} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java index add7358835..e5e2540497 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java @@ -23,6 +23,7 @@ public class AuthFilter implements ContainerRequestFilter { private final AuthService authService; + private final jakarta.inject.Provider requestContext; @Override public void filter(ContainerRequestContext context) throws IOException { @@ -36,7 +37,7 @@ public void filter(ContainerRequestContext context) throws IOException { if (Pattern.matches("/v1/private/.*", requestUri.getPath())) { authService.authenticate(headers, sessionToken); } - + requestContext.get().setHeaders(context.getHeaders()); } HttpHeaders getHttpHeaders(ContainerRequestContext context) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java index 64f964e827..a9f174a2ae 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java @@ -1,6 +1,7 @@ package com.comet.opik.infrastructure.auth; import com.google.inject.servlet.RequestScoped; +import jakarta.ws.rs.core.MultivaluedMap; import lombok.Data; @RequestScoped @@ -13,9 +14,13 @@ public class RequestContext { public static final String SESSION_COOKIE = "sessionToken"; public static final String WORKSPACE_ID = "workspaceId"; public static final String API_KEY = "apiKey"; + public static final String USER_LIMIT = "Opik-User-Limit"; + public static final String USER_REMAINING_LIMIT = "Opik-User-Remaining-Limit"; + public static final String USER_LIMIT_REMAINING_TTL = "Opik-User-Remaining-Limit-TTL-Millis"; private String userName; private String workspaceName; private String workspaceId; private String apiKey; + private MultivaluedMap headers; } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java index 538cbb9a83..7de4f45006 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java @@ -9,9 +9,9 @@ import lombok.extern.slf4j.Slf4j; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; -import reactor.core.publisher.Mono; import java.lang.reflect.Method; +import java.util.List; @Slf4j @RequiredArgsConstructor @@ -55,6 +55,7 @@ public Object invoke(MethodInvocation invocation) throws Throwable { .block(); if (Boolean.TRUE.equals(limitExceeded)) { + setLimitHeaders(apiKey, bucket); throw new ClientErrorException("Too Many Requests", 429); } @@ -63,12 +64,20 @@ public Object invoke(MethodInvocation invocation) throws Throwable { } catch (Exception ex) { decreaseLimitInCaseOfError(bucket, events); throw ex; + } finally { + setLimitHeaders(apiKey, bucket); } } return invocation.proceed(); } + private void setLimitHeaders(String apiKey, String bucket) { + requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(RateLimited.GENERAL_EVENTS)); + requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block())); + requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT, List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block())); + } + private Object getParameters(MethodInvocation method) { for (int i = 0; i < method.getArguments().length; i++) { @@ -82,11 +91,9 @@ private Object getParameters(MethodInvocation method) { private void decreaseLimitInCaseOfError(String bucket, Long events) { try { - Mono.deferContextual(context -> { - String apiKey = context.get(RequestContext.API_KEY); - - return rateLimitService.get().decrement(apiKey, bucket, events); - }).subscribe(); + String apiKey = requestContext.get().getApiKey(); + rateLimitService.get().decrement(apiKey, bucket, events) + .subscribe(); } catch (Exception ex) { log.warn("Failed to decrement rate limit", ex); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java new file mode 100644 index 0000000000..186a5bb432 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java @@ -0,0 +1,36 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.infrastructure.auth.RequestContext; +import jakarta.inject.Inject; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerResponseContext; +import jakarta.ws.rs.container.ContainerResponseFilter; +import jakarta.ws.rs.ext.Provider; +import lombok.RequiredArgsConstructor; + +import java.io.IOException; +import java.util.List; + +@Provider +@RequiredArgsConstructor(onConstructor_ = @Inject) +public class RateLimitResponseFilter implements ContainerResponseFilter { + + @Override + public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) throws IOException { + List userLimit = getValueFromHeader(requestContext, RequestContext.USER_LIMIT); + List remainingLimit = getValueFromHeader(requestContext, RequestContext.USER_REMAINING_LIMIT); + List remainingTtl = getValueFromHeader(requestContext, RequestContext.USER_LIMIT_REMAINING_TTL); + + responseContext.getHeaders().put(RequestContext.USER_LIMIT, userLimit); + responseContext.getHeaders().put(RequestContext.USER_REMAINING_LIMIT, remainingLimit); + responseContext.getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, remainingTtl); + } + + private List getValueFromHeader(ContainerRequestContext requestContext, String key) { + return requestContext.getHeaders().getOrDefault(key, List.of()) + .stream() + .map(Object.class::cast) + .toList(); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java index 20879e7450..1b9cbb9014 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java @@ -8,4 +8,8 @@ Mono isLimitExceeded(String apiKey, long events, String bucketName, lon long limitDurationInSeconds); Mono decrement(String apiKey, String bucketName, long events); + + Mono availableEvents(String apiKey, String bucketName); + + Mono getRemainingTTL(String apiKey, String bucket); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java index b2b02389d2..30f9dbaca9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java @@ -2,58 +2,17 @@ import com.comet.opik.infrastructure.ratelimit.RateLimitService; import lombok.NonNull; -import org.redisson.api.RScript; +import org.redisson.api.RRateLimiterReactive; +import org.redisson.api.RateIntervalUnit; +import org.redisson.api.RateType; import org.redisson.api.RedissonReactiveClient; -import org.redisson.client.codec.StringCodec; import reactor.core.publisher.Mono; -import java.util.List; +import java.time.Duration; public class RedisRateLimitService implements RateLimitService { - private static final String LUA_SCRIPT_ADD = """ - local current = redis.call('GET', KEYS[1]) - if not current then - current = 0 - else - current = tonumber(current) - end - - local limit = tonumber(ARGV[1]) - local increment = tonumber(ARGV[2]) - local ttl = tonumber(ARGV[3]) - - if (current + increment) > limit then - return 0 -- Failure, limit exceeded - else - redis.call('INCRBY', KEYS[1], increment) - - if redis.call('TTL', KEYS[1]) == -1 then - redis.call('EXPIRE', KEYS[1], ttl, 'NX') - end - - return 1 -- Success, increment done - end - """; - - private static final String LUA_SCRIPT_DECR = """ - local current = redis.call('GET', KEYS[1]) - if not current then - current = 0 - else - current = tonumber(current) - end - - local decrement = tonumber(ARGV[1]) - - if (current - decrement) > 0 then - redis.call('DECRBY', KEYS[1], decrement) - else - redis.call('SET', KEYS[1], 0) - end - - return 'OK' - """; + private static final String KEY = "%s:%s"; private final RedissonReactiveClient redisClient; @@ -65,30 +24,30 @@ public RedisRateLimitService(RedissonReactiveClient redisClient) { public Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, long limitDurationInSeconds) { - Mono eval = redisClient.getScript(StringCodec.INSTANCE).eval( - RScript.Mode.READ_WRITE, - LUA_SCRIPT_ADD, - RScript.ReturnType.INTEGER, - List.of(bucketName + ":" + apiKey), - limit, - events, - limitDurationInSeconds); + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); - return eval.map(result -> result == 0); + return rateLimit.trySetRate(RateType.OVERALL, limit, limitDurationInSeconds, RateIntervalUnit.SECONDS) + .then(Mono.defer(() -> rateLimit.expireIfNotSet(Duration.ofSeconds(limitDurationInSeconds)))) + .then(Mono.defer(() -> rateLimit.tryAcquire(events))) + .map(Boolean.FALSE::equals); } @Override public Mono decrement(@NonNull String apiKey, @NonNull String bucketName, long events) { - Mono eval = redisClient.getScript(StringCodec.INSTANCE).eval( - RScript.Mode.READ_WRITE, - LUA_SCRIPT_DECR, - RScript.ReturnType.VALUE, - List.of(bucketName + ":" + apiKey), - events); + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.tryAcquire(-events).then(); + } + + @Override + public Mono availableEvents(@NonNull String apiKey, @NonNull String bucketName) { + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.availablePermits(); + } - return eval.map("OK"::equals) - .switchIfEmpty(Mono.error(new IllegalStateException("Rate limit bucket not found"))) - .then(); + @Override + public Mono getRemainingTTL(@NonNull String apiKey, @NonNull String bucketName) { + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.remainTimeToLive(); } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java index e6f90f921d..aa10a8b2bf 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java @@ -31,7 +31,6 @@ import jakarta.ws.rs.core.Response; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -48,6 +47,7 @@ import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import uk.co.jemos.podam.api.PodamFactory; +import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Set; @@ -64,6 +64,8 @@ import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; import static java.util.stream.Collectors.counting; import static java.util.stream.Collectors.groupingBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; @Testcontainers(parallel = true) @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -156,8 +158,8 @@ void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { Map responseMap = triggerCallsWithApiKey(LIMIT * 2, projectName, apiKey, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(429)); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(429)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -169,12 +171,12 @@ void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { .get()) { // Verify that traces created are equal to the limit - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT, page.content().size()); - Assertions.assertEquals(LIMIT, page.total()); - Assertions.assertEquals(LIMIT, page.size()); + assertEquals(LIMIT, page.content().size()); + assertEquals(LIMIT, page.total()); + assertEquals(LIMIT, page.size()); } } @@ -194,13 +196,13 @@ void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllC Map responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -211,12 +213,12 @@ void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllC .header(WORKSPACE_HEADER, workspaceName) .get()) { - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT * 2, page.content().size()); - Assertions.assertEquals(LIMIT * 2, page.total()); - Assertions.assertEquals(LIMIT * 2, page.size()); + assertEquals(LIMIT * 2, page.content().size()); + assertEquals(LIMIT * 2, page.total()); + assertEquals(LIMIT * 2, page.size()); } } @@ -236,8 +238,8 @@ void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCal Map responseMap = triggerCallsWithCookie(LIMIT * 2, projectName, sessionToken, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(429)); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(429)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -248,12 +250,12 @@ void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCal .header(WORKSPACE_HEADER, workspaceName) .get()) { - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT, page.content().size()); - Assertions.assertEquals(LIMIT, page.total()); - Assertions.assertEquals(LIMIT, page.size()); + assertEquals(LIMIT, page.content().size()); + assertEquals(LIMIT, page.total()); + assertEquals(LIMIT, page.size()); } } @@ -273,13 +275,13 @@ void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAll Map responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -291,18 +293,18 @@ void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAll .get()) { // Verify that traces created are equal to the limit - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT * 2, page.content().size()); - Assertions.assertEquals(LIMIT * 2, page.total()); - Assertions.assertEquals(LIMIT * 2, page.size()); + assertEquals(LIMIT * 2, page.content().size()); + assertEquals(LIMIT * 2, page.total()); + assertEquals(LIMIT * 2, page.size()); } } @Test - @DisplayName("Rate limit: When remaining limit is less than the batch size; Then reject the request") + @DisplayName("Rate limit: When remaining limit is less than the batch size, Then reject the request") void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest() { String apiKey = UUID.randomUUID().toString(); @@ -316,7 +318,7 @@ void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest( Map responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); - Assertions.assertEquals(1, responseMap.get(201)); + assertEquals(1, responseMap.get(201)); List traces = IntStream.range(0, (int) LIMIT) .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() @@ -333,14 +335,14 @@ void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest( .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(new TraceBatch(traces)))) { - Assertions.assertEquals(429, response.getStatus()); + assertEquals(429, response.getStatus()); var error = response.readEntity(ErrorMessage.class); - Assertions.assertEquals("Too Many Requests", error.getMessage()); + assertEquals("Too Many Requests", error.getMessage()); } } @Test - @DisplayName("Rate limit: When after reject request due to batch size; Then accept the request with remaining limit") + @DisplayName("Rate limit: When after reject request due to batch size, Then accept the request with remaining limit") void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRemainingLimit() { String apiKey = UUID.randomUUID().toString(); @@ -354,7 +356,7 @@ void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRe Map responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); - Assertions.assertEquals(1, responseMap.get(201)); + assertEquals(1, responseMap.get(201)); List traces = IntStream.range(0, (int) LIMIT) .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() @@ -371,9 +373,9 @@ void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRe .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(new TraceBatch(traces)))) { - Assertions.assertEquals(429, response.getStatus()); + assertEquals(429, response.getStatus()); var error = response.readEntity(ErrorMessage.class); - Assertions.assertEquals("Too Many Requests", error.getMessage()); + assertEquals("Too Many Requests", error.getMessage()); } try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) @@ -384,13 +386,13 @@ void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRe .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(new TraceBatch(traces.subList(0, (int) LIMIT - 1))))) { - Assertions.assertEquals(204, response.getStatus()); + assertEquals(204, response.getStatus()); } } @ParameterizedTest @MethodSource - @DisplayName("Rate limit: When batch endpoint consumer remaining limit; Then reject next request") + @DisplayName("Rate limit: When batch endpoint consumer remaining limit, Then reject next request") void rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest( Object batch, Object batch2, @@ -412,17 +414,114 @@ void rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest( try (var response = request.method(method, Entity.json(batch))) { - Assertions.assertEquals(204, response.getStatus()); + assertEquals(204, response.getStatus()); } try (var response = request.method(method, Entity.json(batch2))) { - Assertions.assertEquals(429, response.getStatus()); + assertEquals(429, response.getStatus()); var error = response.readEntity(ErrorMessage.class); - Assertions.assertEquals("Too Many Requests", error.getMessage()); + assertEquals("Too Many Requests", error.getMessage()); } } + @Test + @DisplayName("Rate limit: When operation fails after accepting request; Then decrement the limit") + void rateLimit__whenOperationFailsAfterAcceptingRequest__thenDecrementTheLimit() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + // consume 1 from the limit + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + assertEquals(201, response.getStatus()); + } + + // consumer limit - 2 from the limit leaving 1 remaining + Map responseMap = triggerCallsWithApiKey(LIMIT - 2, projectName, apiKey, workspaceName); + + assertEquals(LIMIT - 2, responseMap.get(201)); + + // consume the remaining limit but fail + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + assertEquals(409, response.getStatus()); + } + + // consume the remaining limit + responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); + + assertEquals(1, responseMap.get(201)); + + // verify that the limit is now 0 + responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); + + assertEquals(1, responseMap.get(429)); + } + + @Test + @DisplayName("Rate limit: When processing operations, Then return remaining limit as header") + void rateLimit__whenProcessingOperations__thenReturnRemainingLimitAsHeader() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + IntStream.range(0, (int) LIMIT + 1).forEach(i -> { + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + if (i < LIMIT) { + assertEquals(201, response.getStatus()); + + String remainingLimit = response.getHeaderString(RequestContext.USER_REMAINING_LIMIT); + String userLimit = response.getHeaderString(RequestContext.USER_LIMIT); + String remainingTtl = response.getHeaderString(RequestContext.USER_LIMIT_REMAINING_TTL); + + assertEquals(LIMIT - i - 1, Long.parseLong(remainingLimit)); + assertEquals(RateLimited.GENERAL_EVENTS, userLimit); + assertThat(Long.parseLong(remainingTtl)).isStrictlyBetween(0L, Duration.ofSeconds(LIMIT_DURATION_IN_SECONDS).toMillis()); + } else { + assertEquals(429, response.getStatus()); + } + } + }); + } + public Stream rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest() { var projectName = UUID.randomUUID().toString();