Skip to content

Commit

Permalink
Add new rate limit headers
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagohora committed Sep 19, 2024
1 parent e7718ac commit f9a537b
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 110 deletions.
2 changes: 1 addition & 1 deletion apps/opik-backend/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ rateLimit:
enabled: ${RATE_LIMIT_ENABLED:-false}
generalEvents:
limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-5000}
durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION:-1}
durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION_SEC:-1}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
public class AuthFilter implements ContainerRequestFilter {

private final AuthService authService;
private final jakarta.inject.Provider<RequestContext> requestContext;

@Override
public void filter(ContainerRequestContext context) throws IOException {
Expand All @@ -36,7 +37,7 @@ public void filter(ContainerRequestContext context) throws IOException {
if (Pattern.matches("/v1/private/.*", requestUri.getPath())) {
authService.authenticate(headers, sessionToken);
}

requestContext.get().setHeaders(context.getHeaders());
}

HttpHeaders getHttpHeaders(ContainerRequestContext context) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.infrastructure.auth;

import com.google.inject.servlet.RequestScoped;
import jakarta.ws.rs.core.MultivaluedMap;
import lombok.Data;

@RequestScoped
Expand All @@ -13,9 +14,13 @@ public class RequestContext {
public static final String SESSION_COOKIE = "sessionToken";
public static final String WORKSPACE_ID = "workspaceId";
public static final String API_KEY = "apiKey";
public static final String USER_LIMIT = "Opik-User-Limit";
public static final String USER_REMAINING_LIMIT = "Opik-User-Remaining-Limit";
public static final String USER_LIMIT_REMAINING_TTL = "Opik-User-Remaining-Limit-TTL-Millis";

private String userName;
private String workspaceName;
private String workspaceId;
private String apiKey;
private MultivaluedMap<String, String> headers;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import lombok.extern.slf4j.Slf4j;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import reactor.core.publisher.Mono;

import java.lang.reflect.Method;
import java.util.List;

@Slf4j
@RequiredArgsConstructor
Expand Down Expand Up @@ -55,6 +55,7 @@ public Object invoke(MethodInvocation invocation) throws Throwable {
.block();

if (Boolean.TRUE.equals(limitExceeded)) {
setLimitHeaders(apiKey, bucket);
throw new ClientErrorException("Too Many Requests", 429);
}

Expand All @@ -63,12 +64,20 @@ public Object invoke(MethodInvocation invocation) throws Throwable {
} catch (Exception ex) {
decreaseLimitInCaseOfError(bucket, events);
throw ex;
} finally {
setLimitHeaders(apiKey, bucket);
}
}

return invocation.proceed();
}

private void setLimitHeaders(String apiKey, String bucket) {
requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(RateLimited.GENERAL_EVENTS));
requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block()));
requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT, List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block()));
}

private Object getParameters(MethodInvocation method) {

for (int i = 0; i < method.getArguments().length; i++) {
Expand All @@ -82,11 +91,9 @@ private Object getParameters(MethodInvocation method) {

private void decreaseLimitInCaseOfError(String bucket, Long events) {
try {
Mono.deferContextual(context -> {
String apiKey = context.get(RequestContext.API_KEY);

return rateLimitService.get().decrement(apiKey, bucket, events);
}).subscribe();
String apiKey = requestContext.get().getApiKey();
rateLimitService.get().decrement(apiKey, bucket, events)
.subscribe();
} catch (Exception ex) {
log.warn("Failed to decrement rate limit", ex);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.comet.opik.infrastructure.ratelimit;

import com.comet.opik.infrastructure.auth.RequestContext;
import jakarta.inject.Inject;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerResponseContext;
import jakarta.ws.rs.container.ContainerResponseFilter;
import jakarta.ws.rs.ext.Provider;
import lombok.RequiredArgsConstructor;

import java.io.IOException;
import java.util.List;

@Provider
@RequiredArgsConstructor(onConstructor_ = @Inject)
public class RateLimitResponseFilter implements ContainerResponseFilter {

@Override
public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) throws IOException {
List<Object> userLimit = getValueFromHeader(requestContext, RequestContext.USER_LIMIT);
List<Object> remainingLimit = getValueFromHeader(requestContext, RequestContext.USER_REMAINING_LIMIT);
List<Object> remainingTtl = getValueFromHeader(requestContext, RequestContext.USER_LIMIT_REMAINING_TTL);

responseContext.getHeaders().put(RequestContext.USER_LIMIT, userLimit);
responseContext.getHeaders().put(RequestContext.USER_REMAINING_LIMIT, remainingLimit);
responseContext.getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, remainingTtl);
}

private List<Object> getValueFromHeader(ContainerRequestContext requestContext, String key) {
return requestContext.getHeaders().getOrDefault(key, List.of())
.stream()
.map(Object.class::cast)
.toList();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,8 @@ Mono<Boolean> isLimitExceeded(String apiKey, long events, String bucketName, lon
long limitDurationInSeconds);

Mono<Void> decrement(String apiKey, String bucketName, long events);

Mono<Long> availableEvents(String apiKey, String bucketName);

Mono<Long> getRemainingTTL(String apiKey, String bucket);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,17 @@

import com.comet.opik.infrastructure.ratelimit.RateLimitService;
import lombok.NonNull;
import org.redisson.api.RScript;
import org.redisson.api.RRateLimiterReactive;
import org.redisson.api.RateIntervalUnit;
import org.redisson.api.RateType;
import org.redisson.api.RedissonReactiveClient;
import org.redisson.client.codec.StringCodec;
import reactor.core.publisher.Mono;

import java.util.List;
import java.time.Duration;

public class RedisRateLimitService implements RateLimitService {

private static final String LUA_SCRIPT_ADD = """
local current = redis.call('GET', KEYS[1])
if not current then
current = 0
else
current = tonumber(current)
end
local limit = tonumber(ARGV[1])
local increment = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
if (current + increment) > limit then
return 0 -- Failure, limit exceeded
else
redis.call('INCRBY', KEYS[1], increment)
if redis.call('TTL', KEYS[1]) == -1 then
redis.call('EXPIRE', KEYS[1], ttl, 'NX')
end
return 1 -- Success, increment done
end
""";

private static final String LUA_SCRIPT_DECR = """
local current = redis.call('GET', KEYS[1])
if not current then
current = 0
else
current = tonumber(current)
end
local decrement = tonumber(ARGV[1])
if (current - decrement) > 0 then
redis.call('DECRBY', KEYS[1], decrement)
else
redis.call('SET', KEYS[1], 0)
end
return 'OK'
""";
private static final String KEY = "%s:%s";

private final RedissonReactiveClient redisClient;

Expand All @@ -65,30 +24,30 @@ public RedisRateLimitService(RedissonReactiveClient redisClient) {
public Mono<Boolean> isLimitExceeded(String apiKey, long events, String bucketName, long limit,
long limitDurationInSeconds) {

Mono<Long> eval = redisClient.getScript(StringCodec.INSTANCE).eval(
RScript.Mode.READ_WRITE,
LUA_SCRIPT_ADD,
RScript.ReturnType.INTEGER,
List.of(bucketName + ":" + apiKey),
limit,
events,
limitDurationInSeconds);
RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey));

return eval.map(result -> result == 0);
return rateLimit.trySetRate(RateType.OVERALL, limit, limitDurationInSeconds, RateIntervalUnit.SECONDS)
.then(Mono.defer(() -> rateLimit.expireIfNotSet(Duration.ofSeconds(limitDurationInSeconds))))
.then(Mono.defer(() -> rateLimit.tryAcquire(events)))
.map(Boolean.FALSE::equals);
}

@Override
public Mono<Void> decrement(@NonNull String apiKey, @NonNull String bucketName, long events) {
Mono<String> eval = redisClient.getScript(StringCodec.INSTANCE).eval(
RScript.Mode.READ_WRITE,
LUA_SCRIPT_DECR,
RScript.ReturnType.VALUE,
List.of(bucketName + ":" + apiKey),
events);
RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey));
return rateLimit.tryAcquire(-events).then();
}

@Override
public Mono<Long> availableEvents(@NonNull String apiKey, @NonNull String bucketName) {
RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey));
return rateLimit.availablePermits();
}

return eval.map("OK"::equals)
.switchIfEmpty(Mono.error(new IllegalStateException("Rate limit bucket not found")))
.then();
@Override
public Mono<Long> getRemainingTTL(@NonNull String apiKey, @NonNull String bucketName) {
RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey));
return rateLimit.remainTimeToLive();
}

}
Loading

0 comments on commit f9a537b

Please sign in to comment.