Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-68] Add rate limit #261

Merged
merged 14 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions apps/opik-backend/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,9 @@ server:
enableVirtualThreads: ${ENABLE_VIRTUAL_THREADS:-false}
gzip:
enabled: true

rateLimit:
enabled: ${RATE_LIMIT_ENABLED:-false}
generalEvents:
limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-5000}
durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION:-1}
Nimrod007 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.comet.opik.infrastructure.bundle.LiquibaseBundle;
import com.comet.opik.infrastructure.db.DatabaseAnalyticsModule;
import com.comet.opik.infrastructure.db.IdGeneratorModule;
import com.comet.opik.infrastructure.ratelimit.RateLimitModule;
import com.comet.opik.infrastructure.redis.RedisModule;
import com.comet.opik.utils.JsonBigDecimalDeserializer;
import com.fasterxml.jackson.annotation.JsonInclude;
Expand Down Expand Up @@ -58,7 +59,7 @@ public void initialize(Bootstrap<OpikConfiguration> bootstrap) {
bootstrap.addBundle(GuiceBundle.builder()
.bundles(JdbiBundle.<OpikConfiguration>forDatabase((conf, env) -> conf.getDatabase())
.withPlugins(new SqlObjectPlugin(), new Jackson2Plugin()))
.modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule())
.modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(), new RateLimitModule())
.enableAutoConfig()
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.comet.opik.domain.FeedbackScoreService;
import com.comet.opik.domain.TraceService;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.ratelimit.RateLimited;
import com.comet.opik.utils.AsyncUtils;
import com.fasterxml.jackson.annotation.JsonView;
import io.swagger.v3.oas.annotations.Operation;
Expand Down Expand Up @@ -125,6 +126,7 @@ public Response getById(@PathParam("id") UUID id) {
@Operation(operationId = "createTrace", summary = "Create trace", description = "Get trace", responses = {
@ApiResponse(responseCode = "201", description = "Created", headers = {
@Header(name = "Location", required = true, example = "${basePath}/v1/private/traces/{traceId}", schema = @Schema(implementation = String.class))})})
@RateLimited
public Response create(
@RequestBody(content = @Content(schema = @Schema(implementation = Trace.class))) @JsonView(Trace.View.Write.class) @NotNull @Valid Trace trace,
@Context UriInfo uriInfo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ public class OpikConfiguration extends Configuration {
@Valid
@NotNull @JsonProperty
private DistributedLockConfig distributedLock = new DistributedLockConfig();

@Valid
@NotNull @JsonProperty
private RateLimitConfig rateLimit = new RateLimitConfig();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.comet.opik.infrastructure;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Positive;
import jakarta.validation.constraints.PositiveOrZero;
import lombok.Data;

@Data
public class RateLimitConfig {

public record LimitConfig(@Valid @JsonProperty @PositiveOrZero long limit, @Valid @JsonProperty @Positive long durationInSeconds) {
}

@Valid
@JsonProperty
private boolean enabled;

@Valid
@JsonProperty
private LimitConfig generalEvents;

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public void authenticate(HttpHeaders headers, Cookie sessionToken) {
requestContext.get().setWorkspaceName(currentWorkspaceName);
requestContext.get().setUserName(ProjectService.DEFAULT_USER);
requestContext.get().setWorkspaceId(ProjectService.DEFAULT_WORKSPACE_ID);
requestContext.get().setApiKey("default");
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ private void authenticateUsingSessionToken(Cookie sessionToken, String workspace
AuthResponse credentials = verifyResponse(response);

setCredentialIntoContext(credentials.user(), credentials.workspaceId());
requestContext.get().setApiKey(sessionToken.getValue());
}
}

Expand Down Expand Up @@ -108,6 +109,7 @@ private void authenticateUsingApiKey(HttpHeaders headers, String workspaceName)
}

setCredentialIntoContext(credentials.userName(), credentials.workspaceId());
requestContext.get().setApiKey(apiKey);
}

private ValidatedAuthCredentials validateApiKeyAndGetCredentials(String workspaceName, String apiKey) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,41 +1,21 @@
package com.comet.opik.infrastructure.auth;

import com.google.inject.servlet.RequestScoped;
import lombok.Data;

@RequestScoped
@Data
public class RequestContext {

public static final String WORKSPACE_HEADER = "Comet-Workspace";
public static final String USER_NAME = "userName";
public static final String WORKSPACE_NAME = "workspaceName";
public static final String SESSION_COOKIE = "sessionToken";
public static final String WORKSPACE_ID = "workspaceId";
public static final String API_KEY = "apiKey";

private String userName;
private String workspaceName;
private String workspaceId;

public final String getUserName() {
return userName;
}

public final String getWorkspaceName() {
return workspaceName;
}

public final String getWorkspaceId() {
return workspaceId;
}

void setUserName(String workspaceName) {
this.userName = workspaceName;
}

void setWorkspaceName(String workspaceName) {
this.workspaceName = workspaceName;
}

public void setWorkspaceId(String workspaceId) {
this.workspaceId = workspaceId;
}
private String apiKey;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.comet.opik.infrastructure.ratelimit;

import com.comet.opik.infrastructure.RateLimitConfig;
import com.comet.opik.infrastructure.auth.RequestContext;
import jakarta.inject.Inject;
import jakarta.ws.rs.ClientErrorException;
import lombok.RequiredArgsConstructor;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import reactor.core.publisher.Mono;

import java.lang.reflect.Method;


@RequiredArgsConstructor(onConstructor_ = @Inject)
class RateLimitInterceptor implements MethodInterceptor {

private final RateLimitService rateLimitService;
private final RateLimitConfig rateLimitConfig;

@Override
public Object invoke(MethodInvocation invocation) throws Throwable {

// Get the method being invoked
Method method = invocation.getMethod();

// Check if the method is annotated with @RateLimit
if (!method.isAnnotationPresent(RateLimited.class)) {
return invocation.proceed();
}

RateLimited rateLimit = method.getAnnotation(RateLimited.class);
String bucket = rateLimit.value();

if (!rateLimitConfig.isEnabled()) {
return invocation.proceed();
}
andrescrz marked this conversation as resolved.
Show resolved Hide resolved

// Check if the bucket is the general events bucket
if (bucket.equals(RateLimited.GENERAL_EVENTS)) {
andrescrz marked this conversation as resolved.
Show resolved Hide resolved

long limit = rateLimitConfig.getGeneralEvents().limit();
long limitDurationInSeconds = rateLimitConfig.getGeneralEvents().durationInSeconds();
andrescrz marked this conversation as resolved.
Show resolved Hide resolved

Boolean limitExceeded = Mono.deferContextual(context -> {
String apiKey = context.get(RequestContext.API_KEY);

// Check if the rate limit is exceeded
return rateLimitService.isLimitExceeded(apiKey, bucket, limit, limitDurationInSeconds);
}).block();


if (Boolean.TRUE.equals(limitExceeded)) {
throw new ClientErrorException(429);
}
}

return invocation.proceed();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.comet.opik.infrastructure.ratelimit;

import com.comet.opik.infrastructure.OpikConfiguration;
import com.comet.opik.infrastructure.RateLimitConfig;
import com.google.inject.matcher.Matchers;
import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule;

public class RateLimitModule extends DropwizardAwareModule<OpikConfiguration> {

@Override
protected void configure() {

var rateLimit = configuration(RateLimitService.class);
var config = configuration(RateLimitConfig.class);
var rateLimitInterceptor = new RateLimitInterceptor(rateLimit, config);

bindInterceptor(Matchers.any(), Matchers.annotatedWith(RateLimited.class), rateLimitInterceptor);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.comet.opik.infrastructure.ratelimit;

import reactor.core.publisher.Mono;

public interface RateLimitService {

Mono<Boolean> isLimitExceeded(String apiKey, String bucketName, long limit, long limitDurationInSeconds);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.comet.opik.infrastructure.ratelimit;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimited {

String GENERAL_EVENTS = "general_events";

String value() default GENERAL_EVENTS; // bucket capacity
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.comet.opik.infrastructure.DistributedLockConfig;
import com.comet.opik.infrastructure.OpikConfiguration;
import com.comet.opik.infrastructure.RedisConfig;
import com.comet.opik.infrastructure.ratelimit.RateLimitService;
import com.google.inject.Provides;
import jakarta.inject.Singleton;
import org.redisson.Redisson;
Expand All @@ -25,4 +26,10 @@ public LockService lockService(RedissonReactiveClient redisClient,
return new RedissonLockService(redisClient, distributedLockConfig);
}

@Provides
@Singleton
public RateLimitService rateLimitService(RedissonReactiveClient redisClient) {
return new RedisRateLimitService(redisClient);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.comet.opik.infrastructure.redis;

import com.comet.opik.infrastructure.ratelimit.RateLimitService;
import org.redisson.api.RAtomicLongReactive;
import org.redisson.api.RedissonReactiveClient;
import reactor.core.publisher.Mono;

import java.time.Duration;

public class RedisRateLimitService implements RateLimitService {

private final RedissonReactiveClient redisClient;

public RedisRateLimitService(RedissonReactiveClient redisClient) {
this.redisClient = redisClient;
}
andrescrz marked this conversation as resolved.
Show resolved Hide resolved

@Override
public Mono<Boolean> isLimitExceeded(String apiKey, String bucketName, long limit, long limitDurationInSeconds) {

RAtomicLongReactive limitInstance = redisClient.getAtomicLong(bucketName + ":" + apiKey);

return limitInstance
.incrementAndGet()
.flatMap(count -> {

if (count == 1) {
return limitInstance.expire(Duration.ofSeconds(limitDurationInSeconds))
.map(__ -> count > limit);
}

return Mono.just(count > limit);
});
}
}
3 changes: 3 additions & 0 deletions apps/opik-backend/src/test/resources/config-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ server:
enableVirtualThreads: ${ENABLE_VIRTUAL_THREADS:-false}
gzip:
enabled: true

rateLimit:
enabled: false
andrescrz marked this conversation as resolved.
Show resolved Hide resolved