Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagohora committed Sep 20, 2024
1 parent 27c1f4b commit 004a811
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,27 @@ public Object invoke(MethodInvocation invocation) throws Throwable {
}

RateLimited rateLimit = method.getAnnotation(RateLimited.class);
String bucket = rateLimit.value();

// Check if the bucket is the general events bucket
LimitConfig generalLimit = Optional.ofNullable(rateLimitConfig.getCustomLimits())
.map(limits -> limits.get(bucket))
// Check events bucket
Optional<LimitConfig> limitConfig = Optional.ofNullable(rateLimitConfig.getCustomLimits())
.map(limits -> limits.get(rateLimit.value()));

String limitBucket = limitConfig.isPresent() ? rateLimit.value() : RateLimited.GENERAL_EVENTS;

LimitConfig generalLimit = limitConfig
.orElse(rateLimitConfig.getGeneralLimit());

String apiKey = requestContext.get().getApiKey();
Object body = getParameters(invocation);

long events = body instanceof RateEventContainer container ? container.eventCount() : 1;

verifyRateLimit(events, apiKey, bucket, generalLimit);
verifyRateLimit(events, apiKey, limitBucket, generalLimit);

try {
return invocation.proceed();
} finally {
setLimitHeaders(apiKey, bucket);
setLimitHeaders(apiKey, limitBucket);
}
}

Expand All @@ -77,8 +80,10 @@ private void verifyRateLimit(long events, String apiKey, String bucket, LimitCon

private void setLimitHeaders(String apiKey, String bucket) {
requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(bucket));
requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block()));
requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT, List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block()));
requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL,
List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block()));
requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT,
List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block()));
}

private Object getParameters(MethodInvocation method) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
public class RateLimitResponseFilter implements ContainerResponseFilter {

@Override
public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) throws IOException {
public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext)
throws IOException {
List<Object> userLimit = getValueFromHeader(requestContext, RequestContext.USER_LIMIT);
List<Object> remainingLimit = getValueFromHeader(requestContext, RequestContext.USER_REMAINING_LIMIT);
List<Object> remainingTtl = getValueFromHeader(requestContext, RequestContext.USER_LIMIT_REMAINING_TTL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ public record AppContextConfig(
Long limit,
Long limitDurationInSeconds,
Map<String, LimitConfig> customLimits,
List<Object> customBeans
) {
List<Object> customBeans) {
}

public static TestDropwizardAppExtension newTestDropwizardAppExtension(String jdbcUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,13 +474,8 @@ void rateLimit__whenProcessingOperations__thenReturnRemainingLimitAsHeader() {
if (i < LIMIT) {
assertEquals(HttpStatus.SC_CREATED, response.getStatus());

String remainingLimit = response.getHeaderString(RequestContext.USER_REMAINING_LIMIT);
String userLimit = response.getHeaderString(RequestContext.USER_LIMIT);
String remainingTtl = response.getHeaderString(RequestContext.USER_LIMIT_REMAINING_TTL);

assertEquals(LIMIT - i - 1, Long.parseLong(remainingLimit));
assertEquals(RateLimited.GENERAL_EVENTS, userLimit);
assertThat(Long.parseLong(remainingTtl)).isStrictlyBetween(0L, Duration.ofSeconds(LIMIT_DURATION_IN_SECONDS).toMillis());
assertLimitHeaders(response, LIMIT - i - 1, RateLimited.GENERAL_EVENTS,
(int) LIMIT_DURATION_IN_SECONDS);
} else {
assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus());
}
Expand Down Expand Up @@ -568,6 +563,8 @@ void rateLimit__whenCustomRatedBeanMethodIsCalled__thenRateLimitIsApplied() {
.post(Entity.json(""))) {

assertEquals(HttpStatus.SC_CREATED, response.getStatus());

assertLimitHeaders(response, 0, CUSTOM_LIMIT, 1);
}

try (var response = client.target("%s/v1/private/test".formatted(baseURI))
Expand All @@ -578,10 +575,21 @@ void rateLimit__whenCustomRatedBeanMethodIsCalled__thenRateLimitIsApplied() {
.post(Entity.json(""))) {

assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus());

assertLimitHeaders(response, 0, CUSTOM_LIMIT, 1);
}

}

private static void assertLimitHeaders(Response response, long expected, String limitBucket, int limitDuration) {
String remainingLimit = response.getHeaderString(RequestContext.USER_REMAINING_LIMIT);
String userLimit = response.getHeaderString(RequestContext.USER_LIMIT);
String remainingTtl = response.getHeaderString(RequestContext.USER_LIMIT_REMAINING_TTL);

assertEquals(expected, Long.parseLong(remainingLimit));
assertEquals(limitBucket, userLimit);
assertThat(Long.parseLong(remainingTtl)).isBetween(0L, Duration.ofSeconds(limitDuration).toMillis());
}

private Map<Integer, Long> triggerCallsWithCookie(long limit, String projectName, String sessionToken,
String workspaceName) {
Expand Down

0 comments on commit 004a811

Please sign in to comment.