From 8a16169a3213a41e62f8e84edbc78c203b9e444a Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Tue, 17 Sep 2024 10:57:10 +0200 Subject: [PATCH 1/6] [OPIK-68] Add rate limit --- apps/opik-backend/config.yml | 6 ++ .../java/com/comet/opik/OpikApplication.java | 3 +- .../api/resources/v1/priv/TracesResource.java | 2 + .../infrastructure/OpikConfiguration.java | 4 ++ .../opik/infrastructure/RateLimitConfig.java | 23 +++++++ .../opik/infrastructure/auth/AuthService.java | 1 + .../auth/RemoteAuthService.java | 2 + .../infrastructure/auth/RequestContext.java | 28 ++------- .../ratelimit/RateLimitInterceptor.java | 61 +++++++++++++++++++ .../ratelimit/RateLimitModule.java | 20 ++++++ .../ratelimit/RateLimitService.java | 8 +++ .../infrastructure/ratelimit/RateLimited.java | 15 +++++ .../infrastructure/redis/RedisModule.java | 7 +++ .../redis/RedisRateLimitService.java | 35 +++++++++++ .../src/test/resources/config-test.yml | 3 + 15 files changed, 193 insertions(+), 25 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java diff --git a/apps/opik-backend/config.yml b/apps/opik-backend/config.yml index ec31499b02..d70b40d0a1 100644 --- a/apps/opik-backend/config.yml +++ b/apps/opik-backend/config.yml @@ -65,3 +65,9 @@ server: enableVirtualThreads: ${ENABLE_VIRTUAL_THREADS:-false} gzip: enabled: true + +rateLimit: + enabled: ${RATE_LIMIT_ENABLED:-false} + generalEvents: + limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-5000} + durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION:-1} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java index 40248059ef..da2f63228e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java @@ -5,6 +5,7 @@ import com.comet.opik.infrastructure.bundle.LiquibaseBundle; import com.comet.opik.infrastructure.db.DatabaseAnalyticsModule; import com.comet.opik.infrastructure.db.IdGeneratorModule; +import com.comet.opik.infrastructure.ratelimit.RateLimitModule; import com.comet.opik.infrastructure.redis.RedisModule; import com.comet.opik.utils.JsonBigDecimalDeserializer; import com.fasterxml.jackson.annotation.JsonInclude; @@ -58,7 +59,7 @@ public void initialize(Bootstrap bootstrap) { bootstrap.addBundle(GuiceBundle.builder() .bundles(JdbiBundle.forDatabase((conf, env) -> conf.getDatabase()) .withPlugins(new SqlObjectPlugin(), new Jackson2Plugin())) - .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule()) + .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(), new RateLimitModule()) .enableAutoConfig() .build()); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java index dc91804022..066c35119f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java @@ -15,6 +15,7 @@ import com.comet.opik.domain.FeedbackScoreService; import com.comet.opik.domain.TraceService; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; @@ -125,6 +126,7 @@ public Response getById(@PathParam("id") UUID id) { @Operation(operationId = "createTrace", summary = "Create trace", description = "Get trace", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/traces/{traceId}", schema = @Schema(implementation = String.class))})}) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Trace.class))) @JsonView(Trace.View.Write.class) @NotNull @Valid Trace trace, @Context UriInfo uriInfo) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java index 754a685ade..8bdef4b1e3 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java @@ -33,4 +33,8 @@ public class OpikConfiguration extends Configuration { @Valid @NotNull @JsonProperty private DistributedLockConfig distributedLock = new DistributedLockConfig(); + + @Valid + @NotNull @JsonProperty + private RateLimitConfig rateLimit = new RateLimitConfig(); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java new file mode 100644 index 0000000000..0445a387ab --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java @@ -0,0 +1,23 @@ +package com.comet.opik.infrastructure; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Positive; +import jakarta.validation.constraints.PositiveOrZero; +import lombok.Data; + +@Data +public class RateLimitConfig { + + public record LimitConfig(@Valid @JsonProperty @PositiveOrZero long limit, @Valid @JsonProperty @Positive long durationInSeconds) { + } + + @Valid + @JsonProperty + private boolean enabled; + + @Valid + @JsonProperty + private LimitConfig generalEvents; + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java index 26fdb91625..fa8516355f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java @@ -31,6 +31,7 @@ public void authenticate(HttpHeaders headers, Cookie sessionToken) { requestContext.get().setWorkspaceName(currentWorkspaceName); requestContext.get().setUserName(ProjectService.DEFAULT_USER); requestContext.get().setWorkspaceId(ProjectService.DEFAULT_WORKSPACE_ID); + requestContext.get().setApiKey("default"); return; } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java index f59a328fd8..635ac00aba 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java @@ -81,6 +81,7 @@ private void authenticateUsingSessionToken(Cookie sessionToken, String workspace AuthResponse credentials = verifyResponse(response); setCredentialIntoContext(credentials.user(), credentials.workspaceId()); + requestContext.get().setApiKey(sessionToken.getValue()); } } @@ -108,6 +109,7 @@ private void authenticateUsingApiKey(HttpHeaders headers, String workspaceName) } setCredentialIntoContext(credentials.userName(), credentials.workspaceId()); + requestContext.get().setApiKey(apiKey); } private ValidatedAuthCredentials validateApiKeyAndGetCredentials(String workspaceName, String apiKey) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java index b70a0191e7..64f964e827 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java @@ -1,8 +1,10 @@ package com.comet.opik.infrastructure.auth; import com.google.inject.servlet.RequestScoped; +import lombok.Data; @RequestScoped +@Data public class RequestContext { public static final String WORKSPACE_HEADER = "Comet-Workspace"; @@ -10,32 +12,10 @@ public class RequestContext { public static final String WORKSPACE_NAME = "workspaceName"; public static final String SESSION_COOKIE = "sessionToken"; public static final String WORKSPACE_ID = "workspaceId"; + public static final String API_KEY = "apiKey"; private String userName; private String workspaceName; private String workspaceId; - - public final String getUserName() { - return userName; - } - - public final String getWorkspaceName() { - return workspaceName; - } - - public final String getWorkspaceId() { - return workspaceId; - } - - void setUserName(String workspaceName) { - this.userName = workspaceName; - } - - void setWorkspaceName(String workspaceName) { - this.workspaceName = workspaceName; - } - - public void setWorkspaceId(String workspaceId) { - this.workspaceId = workspaceId; - } + private String apiKey; } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java new file mode 100644 index 0000000000..e421dc9e5d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java @@ -0,0 +1,61 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.infrastructure.RateLimitConfig; +import com.comet.opik.infrastructure.auth.RequestContext; +import jakarta.inject.Inject; +import jakarta.ws.rs.ClientErrorException; +import lombok.RequiredArgsConstructor; +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; +import reactor.core.publisher.Mono; + +import java.lang.reflect.Method; + + +@RequiredArgsConstructor(onConstructor_ = @Inject) +class RateLimitInterceptor implements MethodInterceptor { + + private final RateLimitService rateLimitService; + private final RateLimitConfig rateLimitConfig; + + @Override + public Object invoke(MethodInvocation invocation) throws Throwable { + + // Get the method being invoked + Method method = invocation.getMethod(); + + // Check if the method is annotated with @RateLimit + if (!method.isAnnotationPresent(RateLimited.class)) { + return invocation.proceed(); + } + + RateLimited rateLimit = method.getAnnotation(RateLimited.class); + String bucket = rateLimit.value(); + + if (!rateLimitConfig.isEnabled()) { + return invocation.proceed(); + } + + // Check if the bucket is the general events bucket + if (bucket.equals(RateLimited.GENERAL_EVENTS)) { + + long limit = rateLimitConfig.getGeneralEvents().limit(); + long limitDurationInSeconds = rateLimitConfig.getGeneralEvents().durationInSeconds(); + + Boolean limitExceeded = Mono.deferContextual(context -> { + String apiKey = context.get(RequestContext.API_KEY); + + // Check if the rate limit is exceeded + return rateLimitService.isLimitExceeded(apiKey, bucket, limit, limitDurationInSeconds); + }).block(); + + + if (Boolean.TRUE.equals(limitExceeded)) { + throw new ClientErrorException(429); + } + } + + return invocation.proceed(); + } + +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java new file mode 100644 index 0000000000..2414fd0f67 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java @@ -0,0 +1,20 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.infrastructure.OpikConfiguration; +import com.comet.opik.infrastructure.RateLimitConfig; +import com.google.inject.matcher.Matchers; +import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule; + +public class RateLimitModule extends DropwizardAwareModule { + + @Override + protected void configure() { + + var rateLimit = configuration(RateLimitService.class); + var config = configuration(RateLimitConfig.class); + var rateLimitInterceptor = new RateLimitInterceptor(rateLimit, config); + + bindInterceptor(Matchers.any(), Matchers.annotatedWith(RateLimited.class), rateLimitInterceptor); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java new file mode 100644 index 0000000000..62a1aed5ed --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java @@ -0,0 +1,8 @@ +package com.comet.opik.infrastructure.ratelimit; + +import reactor.core.publisher.Mono; + +public interface RateLimitService { + + Mono isLimitExceeded(String apiKey, String bucketName, long limit, long limitDurationInSeconds); +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java new file mode 100644 index 0000000000..c9a440cf98 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java @@ -0,0 +1,15 @@ +package com.comet.opik.infrastructure.ratelimit; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface RateLimited { + + String GENERAL_EVENTS = "general_events"; + + String value() default GENERAL_EVENTS; // bucket capacity +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java index 69f8f0e4da..3def3365e0 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java @@ -3,6 +3,7 @@ import com.comet.opik.infrastructure.DistributedLockConfig; import com.comet.opik.infrastructure.OpikConfiguration; import com.comet.opik.infrastructure.RedisConfig; +import com.comet.opik.infrastructure.ratelimit.RateLimitService; import com.google.inject.Provides; import jakarta.inject.Singleton; import org.redisson.Redisson; @@ -25,4 +26,10 @@ public LockService lockService(RedissonReactiveClient redisClient, return new RedissonLockService(redisClient, distributedLockConfig); } + @Provides + @Singleton + public RateLimitService rateLimitService(RedissonReactiveClient redisClient) { + return new RedisRateLimitService(redisClient); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java new file mode 100644 index 0000000000..c004b0a10c --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java @@ -0,0 +1,35 @@ +package com.comet.opik.infrastructure.redis; + +import com.comet.opik.infrastructure.ratelimit.RateLimitService; +import org.redisson.api.RAtomicLongReactive; +import org.redisson.api.RedissonReactiveClient; +import reactor.core.publisher.Mono; + +import java.time.Duration; + +public class RedisRateLimitService implements RateLimitService { + + private final RedissonReactiveClient redisClient; + + public RedisRateLimitService(RedissonReactiveClient redisClient) { + this.redisClient = redisClient; + } + + @Override + public Mono isLimitExceeded(String apiKey, String bucketName, long limit, long limitDurationInSeconds) { + + RAtomicLongReactive limitInstance = redisClient.getAtomicLong(bucketName + ":" + apiKey); + + return limitInstance + .incrementAndGet() + .flatMap(count -> { + + if (count == 1) { + return limitInstance.expire(Duration.ofSeconds(limitDurationInSeconds)) + .map(__ -> count > limit); + } + + return Mono.just(count > limit); + }); + } +} diff --git a/apps/opik-backend/src/test/resources/config-test.yml b/apps/opik-backend/src/test/resources/config-test.yml index 412ecf9cd5..b81365d436 100644 --- a/apps/opik-backend/src/test/resources/config-test.yml +++ b/apps/opik-backend/src/test/resources/config-test.yml @@ -65,3 +65,6 @@ server: enableVirtualThreads: ${ENABLE_VIRTUAL_THREADS:-false} gzip: enabled: true + +rateLimit: + enabled: false \ No newline at end of file From e8570d92eb30c1a39a9d01751fe97158d184354d Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Tue, 17 Sep 2024 22:12:22 +0200 Subject: [PATCH 2/6] Add tests --- .../java/com/comet/opik/OpikApplication.java | 3 +- .../resources/v1/priv/DatasetsResource.java | 7 + .../v1/priv/ExperimentsResource.java | 4 + .../v1/priv/FeedbackDefinitionResource.java | 4 + .../resources/v1/priv/ProjectsResource.java | 4 + .../api/resources/v1/priv/SpansResource.java | 8 + .../api/resources/v1/priv/TracesResource.java | 9 +- .../opik/infrastructure/RateLimitConfig.java | 3 +- .../ratelimit/RateLimitInterceptor.java | 43 ++- .../ratelimit/RateLimitModule.java | 7 +- .../ratelimit/RateLimitService.java | 2 + .../infrastructure/ratelimit/RateLimited.java | 4 +- .../redis/RedisRateLimitService.java | 7 + .../TestDropwizardAppExtensionUtils.java | 59 ++- .../v1/priv/DatasetsResourceTest.java | 1 + .../ratelimit/RateLimitE2ETest.java | 336 ++++++++++++++++++ .../ratelimit/RateLimitSetupTest.java | 131 +++++++ 17 files changed, 601 insertions(+), 31 deletions(-) create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java index da2f63228e..3168ce059e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java @@ -59,7 +59,8 @@ public void initialize(Bootstrap bootstrap) { bootstrap.addBundle(GuiceBundle.builder() .bundles(JdbiBundle.forDatabase((conf, env) -> conf.getDatabase()) .withPlugins(new SqlObjectPlugin(), new Jackson2Plugin())) - .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(), new RateLimitModule()) + .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(), + new RateLimitModule()) .enableAutoConfig() .build()); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java index cb54d09b53..779dc7ada9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java @@ -16,6 +16,7 @@ import com.comet.opik.domain.FeedbackScoreDAO; import com.comet.opik.domain.IdGenerator; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.comet.opik.utils.JsonUtils; import com.fasterxml.jackson.annotation.JsonView; @@ -136,6 +137,7 @@ public Response findDatasets( @Header(name = "Location", required = true, example = "${basePath}/api/v1/private/datasets/{id}", schema = @Schema(implementation = String.class)) }) }) + @RateLimited public Response createDataset( @RequestBody(content = @Content(schema = @Schema(implementation = Dataset.class))) @JsonView(Dataset.View.Write.class) @NotNull @Valid Dataset dataset, @Context UriInfo uriInfo) { @@ -156,6 +158,7 @@ public Response createDataset( @Operation(operationId = "updateDataset", summary = "Update dataset by id", description = "Update dataset by id", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response updateDataset(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = DatasetUpdate.class))) @NotNull @Valid DatasetUpdate datasetUpdate) { @@ -172,6 +175,7 @@ public Response updateDataset(@PathParam("id") UUID id, @Operation(operationId = "deleteDataset", summary = "Delete dataset by id", description = "Delete dataset by id", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response deleteDataset(@PathParam("id") UUID id) { String workspaceId = requestContext.get().getWorkspaceId(); @@ -186,6 +190,7 @@ public Response deleteDataset(@PathParam("id") UUID id) { @Operation(operationId = "deleteDatasetByName", summary = "Delete dataset by name", description = "Delete dataset by name", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response deleteDatasetByName( @RequestBody(content = @Content(schema = @Schema(implementation = DatasetIdentifier.class))) @NotNull @Valid DatasetIdentifier identifier) { @@ -346,6 +351,7 @@ private void sendDatasetItems(DatasetItem item, ChunkedOutput writer) @Operation(operationId = "createOrUpdateDatasetItems", summary = "Create/update dataset items", description = "Create/update dataset items based on dataset item id", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response createDatasetItems( @RequestBody(content = @Content(schema = @Schema(implementation = DatasetItemBatch.class))) @JsonView({ DatasetItem.View.Write.class}) @NotNull @Valid DatasetItemBatch batch) { @@ -377,6 +383,7 @@ public Response createDatasetItems( @Operation(operationId = "deleteDatasetItems", summary = "Delete dataset items", description = "Delete dataset items", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response deleteDatasetItems( @RequestBody(content = @Content(schema = @Schema(implementation = DatasetItemsDelete.class))) @NotNull @Valid DatasetItemsDelete request) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java index 816535f380..468199fd6a 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java @@ -10,6 +10,7 @@ import com.comet.opik.domain.ExperimentService; import com.comet.opik.domain.IdGenerator; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; import io.dropwizard.jersey.errors.ErrorMessage; @@ -106,6 +107,7 @@ public Response get(@PathParam("id") UUID id) { @Operation(operationId = "createExperiment", summary = "Create experiment", description = "Create experiment", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/experiments/{id}", schema = @Schema(implementation = String.class))})}) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Experiment.class))) @JsonView(Experiment.View.Write.class) @NotNull @Valid Experiment experiment, @Context UriInfo uriInfo) { @@ -151,6 +153,7 @@ public Response getExperimentItem(@PathParam("id") UUID id) { @Path("/items") @Operation(operationId = "createExperimentItems", summary = "Create experiment items", description = "Create experiment items", responses = { @ApiResponse(responseCode = "204", description = "No content")}) + @RateLimited public Response createExperimentItems( @RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemsBatch.class))) @NotNull @Valid ExperimentItemsBatch request) { @@ -178,6 +181,7 @@ public Response createExperimentItems( @Operation(operationId = "deleteExperimentItems", summary = "Delete experiment items", description = "Delete experiment items", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response deleteExperimentItems( @RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemsDelete.class))) @NotNull @Valid ExperimentItemsDelete request) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java index c3a2d9ccc0..169c44dc61 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java @@ -6,6 +6,7 @@ import com.comet.opik.api.Page; import com.comet.opik.domain.FeedbackDefinitionService; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; @@ -100,6 +101,7 @@ public Response getById(@PathParam("id") @NotNull UUID id) { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/feedback-definitions/{feedbackId}", schema = @Schema(implementation = String.class))}) }) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackDefinition.class))) @JsonView({ FeedbackDefinition.View.Create.class}) @NotNull @Valid FeedbackDefinition feedbackDefinition, @@ -123,6 +125,7 @@ public Response create( @Operation(operationId = "updateFeedbackDefinition", summary = "Update feedback definition by id", description = "Update feedback definition by id", responses = { @ApiResponse(responseCode = "204", description = "No Content") }) + @RateLimited public Response update(final @PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackDefinition.class))) @JsonView({ FeedbackDefinition.View.Update.class}) @NotNull @Valid FeedbackDefinition feedbackDefinition) { @@ -142,6 +145,7 @@ public Response update(final @PathParam("id") UUID id, @Operation(operationId = "deleteFeedbackDefinitionById", summary = "Delete feedback definition by id", description = "Delete feedback definition by id", responses = { @ApiResponse(responseCode = "204", description = "No Content") }) + @RateLimited public Response deleteById(@PathParam("id") UUID id) { String workspaceId = requestContext.get().getWorkspaceId(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index f6e05d4b9c..09b95be1d9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -8,6 +8,7 @@ import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.domain.ProjectService; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; @@ -100,6 +101,7 @@ public Response getById(@PathParam("id") UUID id) { @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) }) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Project.class))) @JsonView(Project.View.Write.class) @Valid Project project, @Context UriInfo uriInfo) { @@ -125,6 +127,7 @@ public Response create( @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) }) + @RateLimited public Response update(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = ProjectUpdate.class))) @Valid ProjectUpdate project) { @@ -143,6 +146,7 @@ public Response update(@PathParam("id") UUID id, @ApiResponse(responseCode = "204", description = "No Content"), @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) }) + @RateLimited public Response deleteById(@PathParam("id") UUID id) { String workspaceId = requestContext.get().getWorkspaceId(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java index 53279ead0d..2a9cb7a18d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java @@ -14,6 +14,7 @@ import com.comet.opik.domain.SpanService; import com.comet.opik.domain.SpanType; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; @@ -126,6 +127,7 @@ public Response getById(@PathParam("id") @NotNull UUID id) { @Operation(operationId = "createSpan", summary = "Create span", description = "Create span", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/spans/{spanId}", schema = @Schema(implementation = String.class))})}) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Span.class))) @JsonView(Span.View.Write.class) @NotNull @Valid Span span, @Context UriInfo uriInfo) { @@ -148,6 +150,7 @@ public Response create( @Path("/batch") @Operation(operationId = "createSpans", summary = "Create spans", description = "Create spans", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response createSpans( @RequestBody(content = @Content(schema = @Schema(implementation = SpanBatch.class))) @JsonView(Span.View.Write.class) @NotNull @Valid SpanBatch spans) { @@ -173,6 +176,7 @@ public Response createSpans( @Operation(operationId = "updateSpan", summary = "Update span by id", description = "Update span by id", responses = { @ApiResponse(responseCode = "204", description = "No Content"), @ApiResponse(responseCode = "404", description = "Not found")}) + @RateLimited public Response update(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = SpanUpdate.class))) @NotNull @Valid SpanUpdate spanUpdate) { @@ -191,6 +195,7 @@ public Response update(@PathParam("id") UUID id, @Operation(operationId = "deleteSpanById", summary = "Delete span by id", description = "Delete span by id", responses = { @ApiResponse(responseCode = "501", description = "Not implemented"), @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response deleteById(@PathParam("id") @NotNull String id) { log.info("Deleting span with id '{}' on workspaceId '{}'", id, requestContext.get().getWorkspaceId()); @@ -201,6 +206,7 @@ public Response deleteById(@PathParam("id") @NotNull String id) { @Path("/{id}/feedback-scores") @Operation(operationId = "addSpanFeedbackScore", summary = "Add span feedback score", description = "Add span feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response addSpanFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackScore.class))) @NotNull @Valid FeedbackScore score) { @@ -219,6 +225,7 @@ public Response addSpanFeedbackScore(@PathParam("id") UUID id, @Path("/{id}/feedback-scores/delete") @Operation(operationId = "deleteSpanFeedbackScore", summary = "Delete span feedback score", description = "Delete span feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response deleteSpanFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = DeleteFeedbackScore.class))) @NotNull @Valid DeleteFeedbackScore score) { @@ -236,6 +243,7 @@ public Response deleteSpanFeedbackScore(@PathParam("id") UUID id, @Path("/feedback-scores") @Operation(operationId = "scoreBatchOfSpans", summary = "Batch feedback scoring for spans", description = "Batch feedback scoring for spans", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response scoreBatchOfSpans( @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackScoreBatch.class))) @NotNull @Valid FeedbackScoreBatch batch) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java index 066c35119f..8ddbabea59 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java @@ -152,7 +152,8 @@ public Response create( @Path("/batch") @Operation(operationId = "createTraces", summary = "Create traces", description = "Create traces", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) - public Response createSpans( + @RateLimited + public Response createTraces( @RequestBody(content = @Content(schema = @Schema(implementation = TraceBatch.class))) @JsonView(Trace.View.Write.class) @NotNull @Valid TraceBatch traces) { traces.traces() @@ -176,6 +177,7 @@ public Response createSpans( @Path("{id}") @Operation(operationId = "updateTrace", summary = "Update trace by id", description = "Update trace by id", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response update(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = TraceUpdate.class))) @Valid @NonNull TraceUpdate trace) { @@ -196,6 +198,7 @@ public Response update(@PathParam("id") UUID id, @Path("{id}") @Operation(operationId = "deleteTraceById", summary = "Delete trace by id", description = "Delete trace by id", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response deleteById(@PathParam("id") UUID id) { log.info("Deleting trace with id '{}'", id); @@ -213,6 +216,7 @@ public Response deleteById(@PathParam("id") UUID id) { @Path("/delete") @Operation(operationId = "deleteTraces", summary = "Delete traces", description = "Delete traces", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response deleteTraces( @RequestBody(content = @Content(schema = @Schema(implementation = TracesDelete.class))) @NotNull @Valid TracesDelete request) { log.info("Deleting traces, count '{}'", request.ids().size()); @@ -227,6 +231,7 @@ public Response deleteTraces( @Path("/{id}/feedback-scores") @Operation(operationId = "addTraceFeedbackScore", summary = "Add trace feedback score", description = "Add trace feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response addTraceFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackScore.class))) @NotNull @Valid FeedbackScore score) { @@ -247,6 +252,7 @@ public Response addTraceFeedbackScore(@PathParam("id") UUID id, @Path("/{id}/feedback-scores/delete") @Operation(operationId = "deleteTraceFeedbackScore", summary = "Delete trace feedback score", description = "Delete trace feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response deleteTraceFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = DeleteFeedbackScore.class))) @NotNull @Valid DeleteFeedbackScore score) { @@ -267,6 +273,7 @@ public Response deleteTraceFeedbackScore(@PathParam("id") UUID id, @Path("/feedback-scores") @Operation(operationId = "scoreBatchOfTraces", summary = "Batch feedback scoring for traces", description = "Batch feedback scoring for traces", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response scoreBatchOfTraces( @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackScoreBatch.class))) @NotNull @Valid FeedbackScoreBatch batch) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java index 0445a387ab..dbaaa04ca9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java @@ -9,7 +9,8 @@ @Data public class RateLimitConfig { - public record LimitConfig(@Valid @JsonProperty @PositiveOrZero long limit, @Valid @JsonProperty @Positive long durationInSeconds) { + public record LimitConfig(@Valid @JsonProperty @PositiveOrZero long limit, + @Valid @JsonProperty @Positive long durationInSeconds) { } @Valid diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java index e421dc9e5d..8267c06522 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java @@ -2,20 +2,22 @@ import com.comet.opik.infrastructure.RateLimitConfig; import com.comet.opik.infrastructure.auth.RequestContext; -import jakarta.inject.Inject; +import jakarta.inject.Provider; import jakarta.ws.rs.ClientErrorException; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; import reactor.core.publisher.Mono; import java.lang.reflect.Method; - -@RequiredArgsConstructor(onConstructor_ = @Inject) +@Slf4j +@RequiredArgsConstructor class RateLimitInterceptor implements MethodInterceptor { - private final RateLimitService rateLimitService; + private final Provider requestContext; + private final Provider rateLimitService; private final RateLimitConfig rateLimitConfig; @Override @@ -41,21 +43,38 @@ public Object invoke(MethodInvocation invocation) throws Throwable { long limit = rateLimitConfig.getGeneralEvents().limit(); long limitDurationInSeconds = rateLimitConfig.getGeneralEvents().durationInSeconds(); + String apiKey = requestContext.get().getApiKey(); - Boolean limitExceeded = Mono.deferContextual(context -> { - String apiKey = context.get(RequestContext.API_KEY); - - // Check if the rate limit is exceeded - return rateLimitService.isLimitExceeded(apiKey, bucket, limit, limitDurationInSeconds); - }).block(); - + // Check if the rate limit is exceeded + Boolean limitExceeded = rateLimitService.get() + .isLimitExceeded(apiKey, bucket, limit, limitDurationInSeconds) + .block(); if (Boolean.TRUE.equals(limitExceeded)) { - throw new ClientErrorException(429); + throw new ClientErrorException("Too Many Requests", 429); + } + + try { + return invocation.proceed(); + } catch (Exception ex) { + decreaseLimitInCaseOfError(bucket); + throw ex; } } return invocation.proceed(); } + private void decreaseLimitInCaseOfError(String bucket) { + try { + Mono.deferContextual(context -> { + String apiKey = context.get(RequestContext.API_KEY); + + return rateLimitService.get().decrement(apiKey, bucket); + }).subscribe(); + } catch (Exception ex) { + log.warn("Failed to decrement rate limit", ex); + } + } + } \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java index 2414fd0f67..693ee823a4 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java @@ -2,6 +2,7 @@ import com.comet.opik.infrastructure.OpikConfiguration; import com.comet.opik.infrastructure.RateLimitConfig; +import com.comet.opik.infrastructure.auth.RequestContext; import com.google.inject.matcher.Matchers; import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule; @@ -10,9 +11,11 @@ public class RateLimitModule extends DropwizardAwareModule { @Override protected void configure() { - var rateLimit = configuration(RateLimitService.class); + var rateLimit = getProvider(RateLimitService.class); var config = configuration(RateLimitConfig.class); - var rateLimitInterceptor = new RateLimitInterceptor(rateLimit, config); + var requestContext = getProvider(RequestContext.class); + + var rateLimitInterceptor = new RateLimitInterceptor(requestContext, rateLimit, config); bindInterceptor(Matchers.any(), Matchers.annotatedWith(RateLimited.class), rateLimitInterceptor); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java index 62a1aed5ed..ea0ab77397 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java @@ -5,4 +5,6 @@ public interface RateLimitService { Mono isLimitExceeded(String apiKey, String bucketName, long limit, long limitDurationInSeconds); + + Mono decrement(String apiKey, String bucketName); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java index c9a440cf98..2ff9122dfd 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java @@ -9,7 +9,7 @@ @Retention(RetentionPolicy.RUNTIME) public @interface RateLimited { - String GENERAL_EVENTS = "general_events"; + String GENERAL_EVENTS = "general_events"; - String value() default GENERAL_EVENTS; // bucket capacity + String value() default GENERAL_EVENTS; // bucket capacity } \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java index c004b0a10c..e5d36b768d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java @@ -32,4 +32,11 @@ public Mono isLimitExceeded(String apiKey, String bucketName, long limi return Mono.just(count > limit); }); } + + @Override + public Mono decrement(String apiKey, String bucketName) { + RAtomicLongReactive limitInstance = redisClient.getAtomicLong(bucketName + ":" + apiKey); + return limitInstance.decrementAndGet().then(); + } + } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java index 1c339b9be4..ba7b3f7539 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java @@ -4,13 +4,28 @@ import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; import com.comet.opik.infrastructure.auth.TestHttpClientUtils; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import lombok.Builder; +import lombok.experimental.UtilityClass; import ru.vyarus.dropwizard.guice.hook.GuiceyConfigurationHook; import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import java.util.ArrayList; +@UtilityClass public class TestDropwizardAppExtensionUtils { + @Builder + public record AppContextConfig( + String jdbcUrl, + DatabaseAnalyticsFactory databaseAnalyticsFactory, + WireMockRuntimeInfo runtimeInfo, + String redisUrl, + Integer cacheTtlInSeconds, + boolean rateLimitEnabled, + Long limit, + Long limitDurationInSeconds) { + } + public static TestDropwizardAppExtension newTestDropwizardAppExtension(String jdbcUrl, WireMockRuntimeInfo runtimeInfo) { return newTestDropwizardAppExtension(jdbcUrl, null, runtimeInfo); @@ -40,23 +55,36 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension( WireMockRuntimeInfo runtimeInfo, String redisUrl, Integer cacheTtlInSeconds) { + return newTestDropwizardAppExtension( + AppContextConfig.builder() + .jdbcUrl(jdbcUrl) + .databaseAnalyticsFactory(databaseAnalyticsFactory) + .runtimeInfo(runtimeInfo) + .redisUrl(redisUrl) + .cacheTtlInSeconds(cacheTtlInSeconds) + .build()); + } + + public static TestDropwizardAppExtension newTestDropwizardAppExtension(AppContextConfig appContextConfig) { var list = new ArrayList(); - list.add("database.url: " + jdbcUrl); + list.add("database.url: " + appContextConfig.jdbcUrl()); - if (databaseAnalyticsFactory != null) { - list.add("databaseAnalytics.port: " + databaseAnalyticsFactory.getPort()); - list.add("databaseAnalytics.username: " + databaseAnalyticsFactory.getUsername()); - list.add("databaseAnalytics.password: " + databaseAnalyticsFactory.getPassword()); + if (appContextConfig.databaseAnalyticsFactory() != null) { + list.add("databaseAnalytics.port: " + appContextConfig.databaseAnalyticsFactory().getPort()); + list.add("databaseAnalytics.username: " + appContextConfig.databaseAnalyticsFactory().getUsername()); + list.add("databaseAnalytics.password: " + appContextConfig.databaseAnalyticsFactory().getPassword()); } - if (runtimeInfo != null) { + if (appContextConfig.runtimeInfo() != null) { list.add("authentication.enabled: true"); - list.add("authentication.sdk.url: " + "%s/opik/auth".formatted(runtimeInfo.getHttpsBaseUrl())); - list.add("authentication.ui.url: " + "%s/opik/auth-session".formatted(runtimeInfo.getHttpsBaseUrl())); + list.add("authentication.sdk.url: " + + "%s/opik/auth".formatted(appContextConfig.runtimeInfo().getHttpsBaseUrl())); + list.add("authentication.ui.url: " + + "%s/opik/auth-session".formatted(appContextConfig.runtimeInfo().getHttpsBaseUrl())); - if (cacheTtlInSeconds != null) { - list.add("authentication.apiKeyResolutionCacheTTLInSec: " + cacheTtlInSeconds); + if (appContextConfig.cacheTtlInSeconds() != null) { + list.add("authentication.apiKeyResolutionCacheTTLInSec: " + appContextConfig.cacheTtlInSeconds()); } } @@ -64,12 +92,19 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension( injector.modulesOverride(TestHttpClientUtils.testAuthModule()); }; - if (redisUrl != null) { - list.add("redis.singleNodeUrl: %s".formatted(redisUrl)); + if (appContextConfig.redisUrl() != null) { + list.add("redis.singleNodeUrl: %s".formatted(appContextConfig.redisUrl())); list.add("redis.sentinelMode: false"); list.add("redis.lockTimeout: 500"); } + if (appContextConfig.rateLimitEnabled()) { + list.add("rateLimit.enabled: true"); + list.add("rateLimit.generalEvents.limit: %d".formatted(appContextConfig.limit())); + list.add("rateLimit.generalEvents.durationInSeconds: %d" + .formatted(appContextConfig.limitDurationInSeconds())); + } + return TestDropwizardAppExtension.forApp(OpikApplication.class) .config("src/test/resources/config-test.yml") .configOverrides(list.toArray(new String[0])) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index 537d8ccc94..9b6a410847 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -123,6 +123,7 @@ class DatasetsResourceTest { private static final TestDropwizardAppExtension app; private static final WireMockRuntime wireMock; + public static final String[] DATASET_IGNORED_FIELDS = {"id", "createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "experimentCount"}; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java new file mode 100644 index 0000000000..cc7ad4d20f --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java @@ -0,0 +1,336 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.api.Trace; +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.podam.PodamFactoryUtils; +import com.redis.testcontainers.RedisContainer; +import io.reactivex.rxjava3.internal.operators.single.SingleDelay; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.containers.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import uk.co.jemos.podam.api.PodamFactory; + +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import static com.comet.opik.api.Trace.TracePage; +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; +import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; +import static com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils.AppContextConfig; +import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; +import static java.util.stream.Collectors.counting; +import static java.util.stream.Collectors.groupingBy; + +@Testcontainers(parallel = true) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("Rate limit Resource Test") +class RateLimitE2ETest { + + private static final String BASE_RESOURCE_URI = "%s/v1/private/traces"; + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + private static final ClickHouseContainer CLICKHOUSE = ClickHouseContainerUtils.newClickHouseContainer(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + private static final WireMockUtils.WireMockRuntime wireMock; + + private static final long LIMIT = 10L; + private static final long LIMIT_DURATION_IN_SECONDS = 1L; + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + static { + MYSQL.start(); + CLICKHOUSE.start(); + REDIS.start(); + + wireMock = WireMockUtils.startWireMock(); + + var databaseAnalyticsFactory = ClickHouseContainerUtils.newDatabaseAnalyticsFactory( + CLICKHOUSE, DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + AppContextConfig.builder() + .jdbcUrl(MYSQL.getJdbcUrl()) + .databaseAnalyticsFactory(databaseAnalyticsFactory) + .runtimeInfo(wireMock.runtimeInfo()) + .redisUrl(REDIS.getRedisURI()) + .rateLimitEnabled(true) + .limit(LIMIT) + .limitDurationInSeconds(LIMIT_DURATION_IN_SECONDS) + .build()); + } + + private String baseURI; + private ClientSupport client; + + @BeforeAll + void setUpAll(ClientSupport client, Jdbi jdbi) throws Exception { + + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICKHOUSE.createConnection("")) { + MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters()); + } + + this.baseURI = "http://localhost:%d".formatted(client.getPort()); + this.client = client; + + ClientSupportUtils.config(client); + } + + @AfterAll + void tearDownAll() { + wireMock.server().stop(); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId, String user) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, user); + } + + private static void mockSessionCookieTargetWorkspace(String sessionToken, String workspaceName, String workspaceId, + String user) { + AuthTestUtils.mockSessionCookieTargetWorkspace(wireMock.server(), sessionToken, workspaceName, workspaceId, + user); + } + + @Test + @DisplayName("Rate limit: When using apiKey and limit is exceeded Then block remaining calls") + void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Map responseMap = triggerCallsWithApiKey(LIMIT * 2, projectName, apiKey, workspaceName); + + // Verify that the rate limit is exceeded + Assertions.assertEquals(LIMIT, responseMap.get(429)); + Assertions.assertEquals(LIMIT, responseMap.get(201)); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .queryParam("project_name", projectName) + .queryParam("size", LIMIT * 2) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + // Verify that traces created are equal to the limit + Assertions.assertEquals(200, response.getStatus()); + TracePage page = response.readEntity(TracePage.class); + + Assertions.assertEquals(LIMIT, page.content().size()); + Assertions.assertEquals(LIMIT, page.total()); + Assertions.assertEquals(LIMIT, page.size()); + } + + } + + @Test + @DisplayName("Rate limit: When using apiKey and limit is not exceeded given duration Then allow all calls") + void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllCalls() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Map responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); + + // Verify that the rate limit is not exceeded + Assertions.assertEquals(LIMIT, responseMap.get(201)); + + SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); + + responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); + + // Verify that the rate limit is not exceeded + Assertions.assertEquals(LIMIT, responseMap.get(201)); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .queryParam("project_name", projectName) + .queryParam("size", LIMIT * 2) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + // Verify that traces created are equal to the limit + Assertions.assertEquals(200, response.getStatus()); + TracePage page = response.readEntity(TracePage.class); + + Assertions.assertEquals(LIMIT * 2, page.content().size()); + Assertions.assertEquals(LIMIT * 2, page.total()); + Assertions.assertEquals(LIMIT * 2, page.size()); + } + + } + + @Test + @DisplayName("Rate limit: When using sessionToken and limit is exceeded Then block remaining calls") + void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCalls() { + + String sessionToken = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockSessionCookieTargetWorkspace(sessionToken, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Map responseMap = triggerCallsWithCookie(LIMIT * 2, projectName, sessionToken, workspaceName); + + // Verify that the rate limit is exceeded + Assertions.assertEquals(LIMIT, responseMap.get(429)); + Assertions.assertEquals(LIMIT, responseMap.get(201)); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .queryParam("project_name", projectName) + .queryParam("size", LIMIT * 2) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(RequestContext.SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + // Verify that traces created are equal to the limit + Assertions.assertEquals(200, response.getStatus()); + TracePage page = response.readEntity(TracePage.class); + + Assertions.assertEquals(LIMIT, page.content().size()); + Assertions.assertEquals(LIMIT, page.total()); + Assertions.assertEquals(LIMIT, page.size()); + } + + } + + @Test + @DisplayName("Rate limit: When using sessionToken and limit is not exceeded given duration Then allow all calls") + void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAllowAllCalls() { + + String sessionToken = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockSessionCookieTargetWorkspace(sessionToken, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Map responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); + + // Verify that the rate limit is not exceeded + Assertions.assertEquals(LIMIT, responseMap.get(201)); + + SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); + + responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); + + // Verify that the rate limit is not exceeded + Assertions.assertEquals(LIMIT, responseMap.get(201)); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .queryParam("project_name", projectName) + .queryParam("size", LIMIT * 2) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(RequestContext.SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + // Verify that traces created are equal to the limit + Assertions.assertEquals(200, response.getStatus()); + TracePage page = response.readEntity(TracePage.class); + + Assertions.assertEquals(LIMIT * 2, page.content().size()); + Assertions.assertEquals(LIMIT * 2, page.total()); + Assertions.assertEquals(LIMIT * 2, page.size()); + } + + } + + private Map triggerCallsWithCookie(long limit, String projectName, String sessionToken, + String workspaceName) { + return Flux.range(0, ((int) limit)) + .flatMap(i -> { + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(RequestContext.SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + return Flux.just(response); + } + }, 5) + .toStream() + .collect(groupingBy(Response::getStatus, counting())); + } + + private Map triggerCallsWithApiKey(long limit, String projectName, String apiKey, + String workspaceName) { + return Flux.range(0, ((int) limit)) + .flatMap(i -> { + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + return Flux.just(response); + } + }, 5) + .toStream() + .collect(groupingBy(Response::getStatus, counting())); + } + +} \ No newline at end of file diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java new file mode 100644 index 0000000000..20a8be36b3 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java @@ -0,0 +1,131 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.api.resources.v1.priv.DatasetsResource; +import com.comet.opik.api.resources.v1.priv.ExperimentsResource; +import com.comet.opik.api.resources.v1.priv.FeedbackDefinitionResource; +import com.comet.opik.api.resources.v1.priv.ProjectsResource; +import com.comet.opik.api.resources.v1.priv.SpansResource; +import com.comet.opik.api.resources.v1.priv.TracesResource; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +class RateLimitSetupTest { + + @Test + void allEventFromDatasetsResourceShouldBeRateLimited() { + + // Given + boolean expectedOutput = Stream + .of("createDataset", "updateDataset", "deleteDataset", "deleteDatasetByName", "createDatasetItems", + "deleteDatasetItems") + .allMatch(methodName -> { + List methods = Arrays.stream(DatasetsResource.class.getMethods()) + .filter(method -> method.getName().equals(methodName)) + .toList(); + + return !methods.isEmpty() && methods.stream() + .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + }); + + // Then + Assertions.assertTrue(expectedOutput); + } + + @Test + void allEventFromExperimentsResourceShouldBeRateLimited() { + + // Given + boolean expectedOutput = Stream.of("create", "createExperimentItems", "deleteExperimentItems") + .allMatch(methodName -> { + List methods = Arrays.stream(ExperimentsResource.class.getMethods()) + .filter(method -> method.getName().equals(methodName)) + .toList(); + + return !methods.isEmpty() && methods.stream() + .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + }); + // Then + Assertions.assertTrue(expectedOutput); + } + + @Test + void allEventFromFeedbackResourceShouldBeRateLimited() { + + // Given + boolean expectedOutput = Stream.of("create", "update", "deleteById") + .allMatch(methodName -> { + List methods = Arrays.stream(FeedbackDefinitionResource.class.getMethods()) + .filter(method -> method.getName().equals(methodName)) + .toList(); + + return !methods.isEmpty() && methods.stream() + .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + }); + // Then + Assertions.assertTrue(expectedOutput); + } + + @Test + void allEventFromProjectsResourceShouldBeRateLimited() { + + // Given + boolean expectedOutput = Stream.of("create", "update", "deleteById") + .allMatch(methodName -> { + List methods = Arrays.stream(ProjectsResource.class.getMethods()) + .filter(method -> method.getName().equals(methodName)) + .toList(); + + return !methods.isEmpty() && methods.stream() + .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + }); + + // Then + Assertions.assertTrue(expectedOutput); + } + + @Test + void allEventFromSpansResourceShouldBeRateLimited() { + + // Given + boolean expectedOutput = Stream + .of("create", "createSpans", "update", "deleteById", "addSpanFeedbackScore", "deleteSpanFeedbackScore", + "scoreBatchOfSpans") + .allMatch(methodName -> { + List methods = Arrays.stream(SpansResource.class.getMethods()) + .filter(method -> method.getName().equals(methodName)) + .toList(); + + return !methods.isEmpty() && methods.stream() + .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + }); + + // Then + Assertions.assertTrue(expectedOutput); + } + + @Test + void allEventFromTracesResourceShouldBeRateLimited() { + + // Given + boolean expectedOutput = Stream + .of("create", "createTraces", "update", "deleteById", "deleteTraces", "addTraceFeedbackScore", + "deleteTraceFeedbackScore", "scoreBatchOfTraces") + .allMatch(methodName -> { + List methods = Arrays.stream(TracesResource.class.getMethods()) + .filter(method -> method.getName().equals(methodName)) + .toList(); + + return !methods.isEmpty() && methods.stream() + .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + }); + + // Then + Assertions.assertTrue(expectedOutput); + } + +} From e7718acf701360c803a2844c81d7ebd0e0e9ec4a Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 19 Sep 2024 00:06:35 +0200 Subject: [PATCH 3/6] Add batch endpoint rate limit --- .../com/comet/opik/api/DatasetItemBatch.java | 9 +- .../comet/opik/api/ExperimentItemsBatch.java | 10 +- .../comet/opik/api/FeedbackScoreBatch.java | 8 +- .../java/com/comet/opik/api/SpanBatch.java | 8 +- .../java/com/comet/opik/api/TraceBatch.java | 8 +- .../resources/v1/priv/DatasetsResource.java | 3 - .../v1/priv/ExperimentsResource.java | 1 - .../v1/priv/FeedbackDefinitionResource.java | 1 - .../resources/v1/priv/ProjectsResource.java | 1 - .../api/resources/v1/priv/SpansResource.java | 2 - .../api/resources/v1/priv/TracesResource.java | 3 - .../opik/domain/FeedbackScoreService.java | 2 +- .../com/comet/opik/domain/SpanService.java | 2 +- .../com/comet/opik/domain/TraceService.java | 2 +- .../opik/infrastructure/auth/AuthModule.java | 2 +- .../auth/RemoteAuthService.java | 4 +- .../{redis => lock}/LockService.java | 2 +- .../ratelimit/RateEventContainer.java | 7 + .../ratelimit/RateLimitInterceptor.java | 29 ++- .../ratelimit/RateLimitService.java | 5 +- .../infrastructure/ratelimit/RateLimited.java | 2 +- .../infrastructure/redis/RedisModule.java | 1 + .../redis/RedisRateLimitService.java | 88 +++++-- .../redis/RedissonLockService.java | 1 + .../comet/opik/domain/DummyLockService.java | 2 +- .../comet/opik/domain/SpanServiceTest.java | 2 +- .../opik/domain/TraceServiceImplTest.java | 2 +- .../ratelimit/RateLimitE2ETest.java | 220 +++++++++++++++++- .../ratelimit/RateLimitSetupTest.java | 106 +++------ .../RedissonLockServiceIntegrationTest.java | 1 + 30 files changed, 398 insertions(+), 136 deletions(-) rename apps/opik-backend/src/main/java/com/comet/opik/infrastructure/{redis => lock}/LockService.java (92%) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateEventContainer.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java index 6c3204d2f4..f03ee900fe 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java @@ -1,6 +1,7 @@ package com.comet.opik.api; import com.comet.opik.api.validate.DatasetItemBatchValidation; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonView; import com.fasterxml.jackson.databind.PropertyNamingStrategies; @@ -26,6 +27,12 @@ public record DatasetItemBatch( DatasetItem.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Schema(description = "If null, dataset_id must be provided") String datasetName, @JsonView({ DatasetItem.View.Write.class}) @Schema(description = "If null, dataset_name must be provided") UUID datasetId, - @JsonView({DatasetItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid List items){ + @JsonView({DatasetItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid List items) + implements + RateEventContainer{ + @Override + public long eventCount() { + return items.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java index 54a9314585..771c517265 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonView; import com.fasterxml.jackson.databind.PropertyNamingStrategies; @@ -16,5 +17,12 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record ExperimentItemsBatch( @JsonView( { - ExperimentItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid Set experimentItems){ + ExperimentItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid Set experimentItems) + implements + RateEventContainer{ + + @Override + public long eventCount() { + return experimentItems.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java index ab798e74ff..1fc5910efe 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.annotation.JsonNaming; @@ -13,6 +14,11 @@ @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) -public record FeedbackScoreBatch(@NotNull @Size(min = 1, max = 1000) @Valid List scores) { +public record FeedbackScoreBatch( + @NotNull @Size(min = 1, max = 1000) @Valid List scores) implements RateEventContainer { + @Override + public long eventCount() { + return scores.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java index 02727bec49..74fa43253f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonView; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; @@ -10,5 +11,10 @@ @Builder(toBuilder = true) public record SpanBatch(@NotNull @Size(min = 1, max = 1000) @JsonView( { - Span.View.Write.class}) @Valid List spans){ + Span.View.Write.class}) @Valid List spans) implements RateEventContainer{ + + @Override + public long eventCount() { + return spans.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java index 0765a89712..fafa64468a 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonView; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; @@ -10,5 +11,10 @@ @Builder(toBuilder = true) public record TraceBatch(@NotNull @Size(min = 1, max = 1000) @JsonView( { - Trace.View.Write.class}) @Valid List traces){ + Trace.View.Write.class}) @Valid List traces) implements RateEventContainer{ + + @Override + public long eventCount() { + return traces.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java index 779dc7ada9..5f80960d18 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java @@ -175,7 +175,6 @@ public Response updateDataset(@PathParam("id") UUID id, @Operation(operationId = "deleteDataset", summary = "Delete dataset by id", description = "Delete dataset by id", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) - @RateLimited public Response deleteDataset(@PathParam("id") UUID id) { String workspaceId = requestContext.get().getWorkspaceId(); @@ -190,7 +189,6 @@ public Response deleteDataset(@PathParam("id") UUID id) { @Operation(operationId = "deleteDatasetByName", summary = "Delete dataset by name", description = "Delete dataset by name", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) - @RateLimited public Response deleteDatasetByName( @RequestBody(content = @Content(schema = @Schema(implementation = DatasetIdentifier.class))) @NotNull @Valid DatasetIdentifier identifier) { @@ -383,7 +381,6 @@ public Response createDatasetItems( @Operation(operationId = "deleteDatasetItems", summary = "Delete dataset items", description = "Delete dataset items", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) - @RateLimited public Response deleteDatasetItems( @RequestBody(content = @Content(schema = @Schema(implementation = DatasetItemsDelete.class))) @NotNull @Valid DatasetItemsDelete request) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java index 468199fd6a..31caf2a3cb 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java @@ -181,7 +181,6 @@ public Response createExperimentItems( @Operation(operationId = "deleteExperimentItems", summary = "Delete experiment items", description = "Delete experiment items", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) - @RateLimited public Response deleteExperimentItems( @RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemsDelete.class))) @NotNull @Valid ExperimentItemsDelete request) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java index 169c44dc61..345503da12 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java @@ -145,7 +145,6 @@ public Response update(final @PathParam("id") UUID id, @Operation(operationId = "deleteFeedbackDefinitionById", summary = "Delete feedback definition by id", description = "Delete feedback definition by id", responses = { @ApiResponse(responseCode = "204", description = "No Content") }) - @RateLimited public Response deleteById(@PathParam("id") UUID id) { String workspaceId = requestContext.get().getWorkspaceId(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index 09b95be1d9..186d52022b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -146,7 +146,6 @@ public Response update(@PathParam("id") UUID id, @ApiResponse(responseCode = "204", description = "No Content"), @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) }) - @RateLimited public Response deleteById(@PathParam("id") UUID id) { String workspaceId = requestContext.get().getWorkspaceId(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java index 2a9cb7a18d..166e293d82 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java @@ -195,7 +195,6 @@ public Response update(@PathParam("id") UUID id, @Operation(operationId = "deleteSpanById", summary = "Delete span by id", description = "Delete span by id", responses = { @ApiResponse(responseCode = "501", description = "Not implemented"), @ApiResponse(responseCode = "204", description = "No Content")}) - @RateLimited public Response deleteById(@PathParam("id") @NotNull String id) { log.info("Deleting span with id '{}' on workspaceId '{}'", id, requestContext.get().getWorkspaceId()); @@ -225,7 +224,6 @@ public Response addSpanFeedbackScore(@PathParam("id") UUID id, @Path("/{id}/feedback-scores/delete") @Operation(operationId = "deleteSpanFeedbackScore", summary = "Delete span feedback score", description = "Delete span feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) - @RateLimited public Response deleteSpanFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = DeleteFeedbackScore.class))) @NotNull @Valid DeleteFeedbackScore score) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java index 8ddbabea59..9da1ab6697 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java @@ -198,7 +198,6 @@ public Response update(@PathParam("id") UUID id, @Path("{id}") @Operation(operationId = "deleteTraceById", summary = "Delete trace by id", description = "Delete trace by id", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) - @RateLimited public Response deleteById(@PathParam("id") UUID id) { log.info("Deleting trace with id '{}'", id); @@ -216,7 +215,6 @@ public Response deleteById(@PathParam("id") UUID id) { @Path("/delete") @Operation(operationId = "deleteTraces", summary = "Delete traces", description = "Delete traces", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) - @RateLimited public Response deleteTraces( @RequestBody(content = @Content(schema = @Schema(implementation = TracesDelete.class))) @NotNull @Valid TracesDelete request) { log.info("Deleting traces, count '{}'", request.ids().size()); @@ -252,7 +250,6 @@ public Response addTraceFeedbackScore(@PathParam("id") UUID id, @Path("/{id}/feedback-scores/delete") @Operation(operationId = "deleteTraceFeedbackScore", summary = "Delete trace feedback score", description = "Delete trace feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) - @RateLimited public Response deleteTraceFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = DeleteFeedbackScore.class))) @NotNull @Valid DeleteFeedbackScore score) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java index 3c2569ba61..a15de2c8a1 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java @@ -8,7 +8,7 @@ import com.comet.opik.api.error.IdentifierMismatchException; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplate; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.utils.WorkspaceUtils; import com.google.inject.ImplementedBy; import com.google.inject.Singleton; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java index 27caa0053f..315f20eee9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java @@ -10,7 +10,7 @@ import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.error.IdentifierMismatchException; import com.comet.opik.infrastructure.auth.RequestContext; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.utils.WorkspaceUtils; import com.google.common.base.Preconditions; import com.newrelic.api.agent.Trace; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java index 956b01eb36..35dfded1da 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java @@ -11,7 +11,7 @@ import com.comet.opik.api.error.IdentifierMismatchException; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplate; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.utils.AsyncUtils; import com.comet.opik.utils.WorkspaceUtils; import com.google.common.base.Preconditions; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java index 757a994ec3..f9cd9a715f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java @@ -2,7 +2,7 @@ import com.comet.opik.infrastructure.AuthenticationConfig; import com.comet.opik.infrastructure.OpikConfiguration; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.google.common.base.Preconditions; import com.google.inject.Provides; import jakarta.inject.Provider; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java index 635ac00aba..68b555392a 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java @@ -1,7 +1,7 @@ package com.comet.opik.infrastructure.auth; import com.comet.opik.domain.ProjectService; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import jakarta.inject.Provider; import jakarta.ws.rs.ClientErrorException; import jakarta.ws.rs.client.Client; @@ -22,7 +22,7 @@ import static com.comet.opik.infrastructure.AuthenticationConfig.UrlConfig; import static com.comet.opik.infrastructure.auth.AuthCredentialsCacheService.AuthCredentials; -import static com.comet.opik.infrastructure.redis.LockService.Lock; +import static com.comet.opik.infrastructure.lock.LockService.Lock; @RequiredArgsConstructor @Slf4j diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/LockService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/lock/LockService.java similarity index 92% rename from apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/LockService.java rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/lock/LockService.java index a02bba19d5..5c9d6a2f54 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/LockService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/lock/LockService.java @@ -1,4 +1,4 @@ -package com.comet.opik.infrastructure.redis; +package com.comet.opik.infrastructure.lock; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateEventContainer.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateEventContainer.java new file mode 100644 index 0000000000..a7f07dc2dc --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateEventContainer.java @@ -0,0 +1,7 @@ +package com.comet.opik.infrastructure.ratelimit; + +public interface RateEventContainer { + + long eventCount(); + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java index 8267c06522..538cbb9a83 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java @@ -2,6 +2,7 @@ import com.comet.opik.infrastructure.RateLimitConfig; import com.comet.opik.infrastructure.auth.RequestContext; +import io.swagger.v3.oas.annotations.parameters.RequestBody; import jakarta.inject.Provider; import jakarta.ws.rs.ClientErrorException; import lombok.RequiredArgsConstructor; @@ -31,23 +32,26 @@ public Object invoke(MethodInvocation invocation) throws Throwable { return invocation.proceed(); } - RateLimited rateLimit = method.getAnnotation(RateLimited.class); - String bucket = rateLimit.value(); - if (!rateLimitConfig.isEnabled()) { return invocation.proceed(); } + RateLimited rateLimit = method.getAnnotation(RateLimited.class); + String bucket = rateLimit.value(); + // Check if the bucket is the general events bucket if (bucket.equals(RateLimited.GENERAL_EVENTS)) { + Object body = getParameters(invocation); + long events = body instanceof RateEventContainer container ? container.eventCount() : 1; + long limit = rateLimitConfig.getGeneralEvents().limit(); long limitDurationInSeconds = rateLimitConfig.getGeneralEvents().durationInSeconds(); String apiKey = requestContext.get().getApiKey(); // Check if the rate limit is exceeded Boolean limitExceeded = rateLimitService.get() - .isLimitExceeded(apiKey, bucket, limit, limitDurationInSeconds) + .isLimitExceeded(apiKey, events, bucket, limit, limitDurationInSeconds) .block(); if (Boolean.TRUE.equals(limitExceeded)) { @@ -57,7 +61,7 @@ public Object invoke(MethodInvocation invocation) throws Throwable { try { return invocation.proceed(); } catch (Exception ex) { - decreaseLimitInCaseOfError(bucket); + decreaseLimitInCaseOfError(bucket, events); throw ex; } } @@ -65,12 +69,23 @@ public Object invoke(MethodInvocation invocation) throws Throwable { return invocation.proceed(); } - private void decreaseLimitInCaseOfError(String bucket) { + private Object getParameters(MethodInvocation method) { + + for (int i = 0; i < method.getArguments().length; i++) { + if (method.getMethod().getParameters()[i].isAnnotationPresent(RequestBody.class)) { + return method.getArguments()[i]; + } + } + + return null; + } + + private void decreaseLimitInCaseOfError(String bucket, Long events) { try { Mono.deferContextual(context -> { String apiKey = context.get(RequestContext.API_KEY); - return rateLimitService.get().decrement(apiKey, bucket); + return rateLimitService.get().decrement(apiKey, bucket, events); }).subscribe(); } catch (Exception ex) { log.warn("Failed to decrement rate limit", ex); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java index ea0ab77397..20879e7450 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java @@ -4,7 +4,8 @@ public interface RateLimitService { - Mono isLimitExceeded(String apiKey, String bucketName, long limit, long limitDurationInSeconds); + Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, + long limitDurationInSeconds); - Mono decrement(String apiKey, String bucketName); + Mono decrement(String apiKey, String bucketName, long events); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java index 2ff9122dfd..e9671f287c 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java @@ -12,4 +12,4 @@ String GENERAL_EVENTS = "general_events"; String value() default GENERAL_EVENTS; // bucket capacity -} \ No newline at end of file +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java index 3def3365e0..adc5b651cc 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java @@ -3,6 +3,7 @@ import com.comet.opik.infrastructure.DistributedLockConfig; import com.comet.opik.infrastructure.OpikConfiguration; import com.comet.opik.infrastructure.RedisConfig; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.infrastructure.ratelimit.RateLimitService; import com.google.inject.Provides; import jakarta.inject.Singleton; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java index e5d36b768d..b2b02389d2 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java @@ -1,14 +1,60 @@ package com.comet.opik.infrastructure.redis; import com.comet.opik.infrastructure.ratelimit.RateLimitService; -import org.redisson.api.RAtomicLongReactive; +import lombok.NonNull; +import org.redisson.api.RScript; import org.redisson.api.RedissonReactiveClient; +import org.redisson.client.codec.StringCodec; import reactor.core.publisher.Mono; -import java.time.Duration; +import java.util.List; public class RedisRateLimitService implements RateLimitService { + private static final String LUA_SCRIPT_ADD = """ + local current = redis.call('GET', KEYS[1]) + if not current then + current = 0 + else + current = tonumber(current) + end + + local limit = tonumber(ARGV[1]) + local increment = tonumber(ARGV[2]) + local ttl = tonumber(ARGV[3]) + + if (current + increment) > limit then + return 0 -- Failure, limit exceeded + else + redis.call('INCRBY', KEYS[1], increment) + + if redis.call('TTL', KEYS[1]) == -1 then + redis.call('EXPIRE', KEYS[1], ttl, 'NX') + end + + return 1 -- Success, increment done + end + """; + + private static final String LUA_SCRIPT_DECR = """ + local current = redis.call('GET', KEYS[1]) + if not current then + current = 0 + else + current = tonumber(current) + end + + local decrement = tonumber(ARGV[1]) + + if (current - decrement) > 0 then + redis.call('DECRBY', KEYS[1], decrement) + else + redis.call('SET', KEYS[1], 0) + end + + return 'OK' + """; + private final RedissonReactiveClient redisClient; public RedisRateLimitService(RedissonReactiveClient redisClient) { @@ -16,27 +62,33 @@ public RedisRateLimitService(RedissonReactiveClient redisClient) { } @Override - public Mono isLimitExceeded(String apiKey, String bucketName, long limit, long limitDurationInSeconds) { + public Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, + long limitDurationInSeconds) { - RAtomicLongReactive limitInstance = redisClient.getAtomicLong(bucketName + ":" + apiKey); + Mono eval = redisClient.getScript(StringCodec.INSTANCE).eval( + RScript.Mode.READ_WRITE, + LUA_SCRIPT_ADD, + RScript.ReturnType.INTEGER, + List.of(bucketName + ":" + apiKey), + limit, + events, + limitDurationInSeconds); - return limitInstance - .incrementAndGet() - .flatMap(count -> { - - if (count == 1) { - return limitInstance.expire(Duration.ofSeconds(limitDurationInSeconds)) - .map(__ -> count > limit); - } - - return Mono.just(count > limit); - }); + return eval.map(result -> result == 0); } @Override - public Mono decrement(String apiKey, String bucketName) { - RAtomicLongReactive limitInstance = redisClient.getAtomicLong(bucketName + ":" + apiKey); - return limitInstance.decrementAndGet().then(); + public Mono decrement(@NonNull String apiKey, @NonNull String bucketName, long events) { + Mono eval = redisClient.getScript(StringCodec.INSTANCE).eval( + RScript.Mode.READ_WRITE, + LUA_SCRIPT_DECR, + RScript.ReturnType.VALUE, + List.of(bucketName + ":" + apiKey), + events); + + return eval.map("OK"::equals) + .switchIfEmpty(Mono.error(new IllegalStateException("Rate limit bucket not found"))) + .then(); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java index 5ef02eac86..61589333a0 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java @@ -1,6 +1,7 @@ package com.comet.opik.infrastructure.redis; import com.comet.opik.infrastructure.DistributedLockConfig; +import com.comet.opik.infrastructure.lock.LockService; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java index faf96f9979..90e3354322 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java @@ -1,6 +1,6 @@ package com.comet.opik.domain; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java index c5e404b98a..31a4184f66 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java @@ -2,7 +2,7 @@ import com.comet.opik.api.SpanUpdate; import com.comet.opik.api.error.InvalidUUIDVersionException; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.podam.PodamFactoryUtils; import com.fasterxml.uuid.Generators; import com.fasterxml.uuid.impl.TimeBasedEpochGenerator; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java index 2188190fdc..cce25f38bd 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java @@ -8,7 +8,7 @@ import com.comet.opik.api.error.InvalidUUIDVersionException; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplate; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.fasterxml.uuid.Generators; import io.r2dbc.spi.Connection; import org.junit.jupiter.api.Assertions; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java index cc7ad4d20f..e6f90f921d 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java @@ -1,6 +1,15 @@ package com.comet.opik.infrastructure.ratelimit; +import com.comet.opik.api.DatasetItem; +import com.comet.opik.api.DatasetItemBatch; +import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.ExperimentItemsBatch; +import com.comet.opik.api.FeedbackScoreBatch; +import com.comet.opik.api.FeedbackScoreBatchItem; +import com.comet.opik.api.Span; +import com.comet.opik.api.SpanBatch; import com.comet.opik.api.Trace; +import com.comet.opik.api.TraceBatch; import com.comet.opik.api.resources.utils.AuthTestUtils; import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; import com.comet.opik.api.resources.utils.ClientSupportUtils; @@ -12,8 +21,11 @@ import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.podam.PodamFactoryUtils; import com.redis.testcontainers.RedisContainer; +import io.dropwizard.jersey.errors.ErrorMessage; import io.reactivex.rxjava3.internal.operators.single.SingleDelay; +import jakarta.ws.rs.HttpMethod; import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.client.Invocation; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; @@ -25,6 +37,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.testcontainers.containers.ClickHouseContainer; import org.testcontainers.containers.MySQLContainer; import org.testcontainers.junit.jupiter.Testcontainers; @@ -33,9 +48,14 @@ import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import uk.co.jemos.podam.api.PodamFactory; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; import static com.comet.opik.api.Trace.TracePage; import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; @@ -122,7 +142,7 @@ private static void mockSessionCookieTargetWorkspace(String sessionToken, String } @Test - @DisplayName("Rate limit: When using apiKey and limit is exceeded Then block remaining calls") + @DisplayName("Rate limit: When using apiKey and limit is exceeded, Then block remaining calls") void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { String apiKey = UUID.randomUUID().toString(); @@ -136,7 +156,6 @@ void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { Map responseMap = triggerCallsWithApiKey(LIMIT * 2, projectName, apiKey, workspaceName); - // Verify that the rate limit is exceeded Assertions.assertEquals(LIMIT, responseMap.get(429)); Assertions.assertEquals(LIMIT, responseMap.get(201)); @@ -161,7 +180,7 @@ void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { } @Test - @DisplayName("Rate limit: When using apiKey and limit is not exceeded given duration Then allow all calls") + @DisplayName("Rate limit: When using apiKey and limit is not exceeded given duration, Then allow all calls") void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllCalls() { String apiKey = UUID.randomUUID().toString(); @@ -175,14 +194,12 @@ void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllC Map responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); - // Verify that the rate limit is not exceeded Assertions.assertEquals(LIMIT, responseMap.get(201)); SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); - // Verify that the rate limit is not exceeded Assertions.assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) @@ -194,7 +211,6 @@ void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllC .header(WORKSPACE_HEADER, workspaceName) .get()) { - // Verify that traces created are equal to the limit Assertions.assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); @@ -206,7 +222,7 @@ void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllC } @Test - @DisplayName("Rate limit: When using sessionToken and limit is exceeded Then block remaining calls") + @DisplayName("Rate limit: When using sessionToken and limit is exceeded, Then block remaining calls") void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCalls() { String sessionToken = UUID.randomUUID().toString(); @@ -220,7 +236,6 @@ void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCal Map responseMap = triggerCallsWithCookie(LIMIT * 2, projectName, sessionToken, workspaceName); - // Verify that the rate limit is exceeded Assertions.assertEquals(LIMIT, responseMap.get(429)); Assertions.assertEquals(LIMIT, responseMap.get(201)); @@ -233,7 +248,6 @@ void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCal .header(WORKSPACE_HEADER, workspaceName) .get()) { - // Verify that traces created are equal to the limit Assertions.assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); @@ -245,7 +259,7 @@ void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCal } @Test - @DisplayName("Rate limit: When using sessionToken and limit is not exceeded given duration Then allow all calls") + @DisplayName("Rate limit: When using sessionToken and limit is not exceeded given duration, Then allow all calls") void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAllowAllCalls() { String sessionToken = UUID.randomUUID().toString(); @@ -259,14 +273,12 @@ void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAll Map responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); - // Verify that the rate limit is not exceeded Assertions.assertEquals(LIMIT, responseMap.get(201)); SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); - // Verify that the rate limit is not exceeded Assertions.assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) @@ -289,6 +301,190 @@ void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAll } + @Test + @DisplayName("Rate limit: When remaining limit is less than the batch size; Then reject the request") + void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Map responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); + + Assertions.assertEquals(1, responseMap.get(201)); + + List traces = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .projectId(null) + .build()) + .toList(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .path("batch") + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(new TraceBatch(traces)))) { + + Assertions.assertEquals(429, response.getStatus()); + var error = response.readEntity(ErrorMessage.class); + Assertions.assertEquals("Too Many Requests", error.getMessage()); + } + } + + @Test + @DisplayName("Rate limit: When after reject request due to batch size; Then accept the request with remaining limit") + void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRemainingLimit() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Map responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); + + Assertions.assertEquals(1, responseMap.get(201)); + + List traces = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .projectId(null) + .build()) + .toList(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .path("batch") + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(new TraceBatch(traces)))) { + + Assertions.assertEquals(429, response.getStatus()); + var error = response.readEntity(ErrorMessage.class); + Assertions.assertEquals("Too Many Requests", error.getMessage()); + } + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .path("batch") + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(new TraceBatch(traces.subList(0, (int) LIMIT - 1))))) { + + Assertions.assertEquals(204, response.getStatus()); + } + } + + @ParameterizedTest + @MethodSource + @DisplayName("Rate limit: When batch endpoint consumer remaining limit; Then reject next request") + void rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest( + Object batch, + Object batch2, + String url, + String method) { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + Invocation.Builder request = client.target(url) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName); + + try (var response = request.method(method, Entity.json(batch))) { + + Assertions.assertEquals(204, response.getStatus()); + } + + try (var response = request.method(method, Entity.json(batch2))) { + + Assertions.assertEquals(429, response.getStatus()); + var error = response.readEntity(ErrorMessage.class); + Assertions.assertEquals("Too Many Requests", error.getMessage()); + } + } + + public Stream rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest() { + + var projectName = UUID.randomUUID().toString(); + + var traces = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .projectId(null) + .build()) + .toList(); + + var spans = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(Span.class).toBuilder() + .projectName(projectName) + .projectId(null) + .parentSpanId(null) + .build()) + .toList(); + + var datasetItems = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(DatasetItem.class).toBuilder() + .experimentItems(null) + .build()) + .toList(); + + var tracesFeedbackScores = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() + .projectId(null) + .build()) + .toList(); + + var spansFeedbackScores = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() + .projectId(null) + .build()) + .toList(); + + var experimentItems = IntStream.range(0, (int) LIMIT) + .mapToObj(i -> factory.manufacturePojo(ExperimentItem.class).toBuilder() + .feedbackScores(null) + .build()) + .collect(Collectors.toSet()); + + return Stream.of( + Arguments.of(new TraceBatch(traces), new TraceBatch(List.of(traces.getFirst())), + BASE_RESOURCE_URI.formatted(baseURI) + "/batch", HttpMethod.POST), + Arguments.of(new SpanBatch(spans), new SpanBatch(List.of(spans.getFirst())), + "%s/v1/private/spans".formatted(baseURI) + "/batch", HttpMethod.POST), + Arguments.of(new DatasetItemBatch(projectName, null, datasetItems), + new DatasetItemBatch(projectName, null, List.of(datasetItems.getFirst())), + "%s/v1/private/datasets".formatted(baseURI) + "/items", HttpMethod.PUT), + Arguments.of(new FeedbackScoreBatch(tracesFeedbackScores), + new FeedbackScoreBatch(List.of(tracesFeedbackScores.getFirst())), + BASE_RESOURCE_URI.formatted(baseURI) + "/feedback-scores", HttpMethod.PUT), + Arguments.of(new FeedbackScoreBatch(spansFeedbackScores), + new FeedbackScoreBatch(List.of(spansFeedbackScores.getFirst())), + "%s/v1/private/spans".formatted(baseURI) + "/feedback-scores", HttpMethod.PUT), + Arguments.of(new ExperimentItemsBatch(experimentItems), + new ExperimentItemsBatch(Set.of(experimentItems.stream().findFirst().orElseThrow())), + "%s/v1/private/experiments".formatted(baseURI) + "/items", HttpMethod.POST)); + } + private Map triggerCallsWithCookie(long limit, String projectName, String sessionToken, String workspaceName) { return Flux.range(0, ((int) limit)) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java index 20a8be36b3..c763ce3c00 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java @@ -20,112 +20,78 @@ class RateLimitSetupTest { void allEventFromDatasetsResourceShouldBeRateLimited() { // Given - boolean expectedOutput = Stream - .of("createDataset", "updateDataset", "deleteDataset", "deleteDatasetByName", "createDatasetItems", - "deleteDatasetItems") - .allMatch(methodName -> { - List methods = Arrays.stream(DatasetsResource.class.getMethods()) - .filter(method -> method.getName().equals(methodName)) - .toList(); - - return !methods.isEmpty() && methods.stream() - .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + Stream.of("createDataset", "updateDataset", "createDatasetItems") + .forEach(methodName -> { + // Then + assertIfMethodAreAnnotated(methodName, DatasetsResource.class); }); - - // Then - Assertions.assertTrue(expectedOutput); } @Test void allEventFromExperimentsResourceShouldBeRateLimited() { // Given - boolean expectedOutput = Stream.of("create", "createExperimentItems", "deleteExperimentItems") - .allMatch(methodName -> { - List methods = Arrays.stream(ExperimentsResource.class.getMethods()) - .filter(method -> method.getName().equals(methodName)) - .toList(); - - return !methods.isEmpty() && methods.stream() - .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + Stream.of("create", "createExperimentItems") + .forEach(methodName -> { + // Then + assertIfMethodAreAnnotated(methodName, ExperimentsResource.class); }); - // Then - Assertions.assertTrue(expectedOutput); + } + + private void assertIfMethodAreAnnotated(String methodName, Class targetClass) { + List targetMethods = Arrays.stream(targetClass.getMethods()) + .filter(method -> method.getName().equals(methodName)) + .toList(); + + boolean actualMatch = !targetMethods.isEmpty() && targetMethods.stream() + .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + + Assertions.assertTrue(actualMatch, + "Method %s.%s is not annotated".formatted(targetClass.getSimpleName(), methodName)); } @Test void allEventFromFeedbackResourceShouldBeRateLimited() { // Given - boolean expectedOutput = Stream.of("create", "update", "deleteById") - .allMatch(methodName -> { - List methods = Arrays.stream(FeedbackDefinitionResource.class.getMethods()) - .filter(method -> method.getName().equals(methodName)) - .toList(); - - return !methods.isEmpty() && methods.stream() - .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + Stream.of("create", "update") + .forEach(methodName -> { + // Then + assertIfMethodAreAnnotated(methodName, FeedbackDefinitionResource.class); }); - // Then - Assertions.assertTrue(expectedOutput); } @Test void allEventFromProjectsResourceShouldBeRateLimited() { // Given - boolean expectedOutput = Stream.of("create", "update", "deleteById") - .allMatch(methodName -> { - List methods = Arrays.stream(ProjectsResource.class.getMethods()) - .filter(method -> method.getName().equals(methodName)) - .toList(); - - return !methods.isEmpty() && methods.stream() - .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + Stream.of("create", "update") + .forEach(methodName -> { + // Then + assertIfMethodAreAnnotated(methodName, ProjectsResource.class); }); - - // Then - Assertions.assertTrue(expectedOutput); } @Test void allEventFromSpansResourceShouldBeRateLimited() { // Given - boolean expectedOutput = Stream - .of("create", "createSpans", "update", "deleteById", "addSpanFeedbackScore", "deleteSpanFeedbackScore", - "scoreBatchOfSpans") - .allMatch(methodName -> { - List methods = Arrays.stream(SpansResource.class.getMethods()) - .filter(method -> method.getName().equals(methodName)) - .toList(); - - return !methods.isEmpty() && methods.stream() - .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + Stream.of("create", "createSpans", "update", "addSpanFeedbackScore", "scoreBatchOfSpans") + .forEach(methodName -> { + // Then + assertIfMethodAreAnnotated(methodName, SpansResource.class); }); - - // Then - Assertions.assertTrue(expectedOutput); } @Test void allEventFromTracesResourceShouldBeRateLimited() { // Given - boolean expectedOutput = Stream - .of("create", "createTraces", "update", "deleteById", "deleteTraces", "addTraceFeedbackScore", - "deleteTraceFeedbackScore", "scoreBatchOfTraces") - .allMatch(methodName -> { - List methods = Arrays.stream(TracesResource.class.getMethods()) - .filter(method -> method.getName().equals(methodName)) - .toList(); - - return !methods.isEmpty() && methods.stream() - .allMatch(method -> method.isAnnotationPresent(RateLimited.class)); + Stream.of("create", "createTraces", "update", "addTraceFeedbackScore", "scoreBatchOfTraces") + .forEach(methodName -> { + // Then + assertIfMethodAreAnnotated(methodName, TracesResource.class); }); - - // Then - Assertions.assertTrue(expectedOutput); } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java index 728921b170..660158afa5 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java @@ -4,6 +4,7 @@ import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.infrastructure.lock.LockService; import com.redis.testcontainers.RedisContainer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; From f9a537be64f1ab3c1bbe1b57e1a5cdb64d07eb16 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 19 Sep 2024 16:33:52 +0200 Subject: [PATCH 4/6] Add new rate limit headers --- apps/opik-backend/config.yml | 2 +- .../opik/infrastructure/auth/AuthFilter.java | 3 +- .../infrastructure/auth/RequestContext.java | 5 + .../ratelimit/RateLimitInterceptor.java | 19 +- .../ratelimit/RateLimitResponseFilter.java | 36 ++++ .../ratelimit/RateLimitService.java | 4 + .../redis/RedisRateLimitService.java | 87 +++------ .../ratelimit/RateLimitE2ETest.java | 175 ++++++++++++++---- 8 files changed, 221 insertions(+), 110 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java diff --git a/apps/opik-backend/config.yml b/apps/opik-backend/config.yml index d70b40d0a1..25e11fd1b7 100644 --- a/apps/opik-backend/config.yml +++ b/apps/opik-backend/config.yml @@ -70,4 +70,4 @@ rateLimit: enabled: ${RATE_LIMIT_ENABLED:-false} generalEvents: limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-5000} - durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION:-1} \ No newline at end of file + durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION_SEC:-1} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java index add7358835..e5e2540497 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java @@ -23,6 +23,7 @@ public class AuthFilter implements ContainerRequestFilter { private final AuthService authService; + private final jakarta.inject.Provider requestContext; @Override public void filter(ContainerRequestContext context) throws IOException { @@ -36,7 +37,7 @@ public void filter(ContainerRequestContext context) throws IOException { if (Pattern.matches("/v1/private/.*", requestUri.getPath())) { authService.authenticate(headers, sessionToken); } - + requestContext.get().setHeaders(context.getHeaders()); } HttpHeaders getHttpHeaders(ContainerRequestContext context) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java index 64f964e827..a9f174a2ae 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java @@ -1,6 +1,7 @@ package com.comet.opik.infrastructure.auth; import com.google.inject.servlet.RequestScoped; +import jakarta.ws.rs.core.MultivaluedMap; import lombok.Data; @RequestScoped @@ -13,9 +14,13 @@ public class RequestContext { public static final String SESSION_COOKIE = "sessionToken"; public static final String WORKSPACE_ID = "workspaceId"; public static final String API_KEY = "apiKey"; + public static final String USER_LIMIT = "Opik-User-Limit"; + public static final String USER_REMAINING_LIMIT = "Opik-User-Remaining-Limit"; + public static final String USER_LIMIT_REMAINING_TTL = "Opik-User-Remaining-Limit-TTL-Millis"; private String userName; private String workspaceName; private String workspaceId; private String apiKey; + private MultivaluedMap headers; } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java index 538cbb9a83..7de4f45006 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java @@ -9,9 +9,9 @@ import lombok.extern.slf4j.Slf4j; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; -import reactor.core.publisher.Mono; import java.lang.reflect.Method; +import java.util.List; @Slf4j @RequiredArgsConstructor @@ -55,6 +55,7 @@ public Object invoke(MethodInvocation invocation) throws Throwable { .block(); if (Boolean.TRUE.equals(limitExceeded)) { + setLimitHeaders(apiKey, bucket); throw new ClientErrorException("Too Many Requests", 429); } @@ -63,12 +64,20 @@ public Object invoke(MethodInvocation invocation) throws Throwable { } catch (Exception ex) { decreaseLimitInCaseOfError(bucket, events); throw ex; + } finally { + setLimitHeaders(apiKey, bucket); } } return invocation.proceed(); } + private void setLimitHeaders(String apiKey, String bucket) { + requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(RateLimited.GENERAL_EVENTS)); + requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block())); + requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT, List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block())); + } + private Object getParameters(MethodInvocation method) { for (int i = 0; i < method.getArguments().length; i++) { @@ -82,11 +91,9 @@ private Object getParameters(MethodInvocation method) { private void decreaseLimitInCaseOfError(String bucket, Long events) { try { - Mono.deferContextual(context -> { - String apiKey = context.get(RequestContext.API_KEY); - - return rateLimitService.get().decrement(apiKey, bucket, events); - }).subscribe(); + String apiKey = requestContext.get().getApiKey(); + rateLimitService.get().decrement(apiKey, bucket, events) + .subscribe(); } catch (Exception ex) { log.warn("Failed to decrement rate limit", ex); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java new file mode 100644 index 0000000000..186a5bb432 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java @@ -0,0 +1,36 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.infrastructure.auth.RequestContext; +import jakarta.inject.Inject; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerResponseContext; +import jakarta.ws.rs.container.ContainerResponseFilter; +import jakarta.ws.rs.ext.Provider; +import lombok.RequiredArgsConstructor; + +import java.io.IOException; +import java.util.List; + +@Provider +@RequiredArgsConstructor(onConstructor_ = @Inject) +public class RateLimitResponseFilter implements ContainerResponseFilter { + + @Override + public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) throws IOException { + List userLimit = getValueFromHeader(requestContext, RequestContext.USER_LIMIT); + List remainingLimit = getValueFromHeader(requestContext, RequestContext.USER_REMAINING_LIMIT); + List remainingTtl = getValueFromHeader(requestContext, RequestContext.USER_LIMIT_REMAINING_TTL); + + responseContext.getHeaders().put(RequestContext.USER_LIMIT, userLimit); + responseContext.getHeaders().put(RequestContext.USER_REMAINING_LIMIT, remainingLimit); + responseContext.getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, remainingTtl); + } + + private List getValueFromHeader(ContainerRequestContext requestContext, String key) { + return requestContext.getHeaders().getOrDefault(key, List.of()) + .stream() + .map(Object.class::cast) + .toList(); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java index 20879e7450..1b9cbb9014 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java @@ -8,4 +8,8 @@ Mono isLimitExceeded(String apiKey, long events, String bucketName, lon long limitDurationInSeconds); Mono decrement(String apiKey, String bucketName, long events); + + Mono availableEvents(String apiKey, String bucketName); + + Mono getRemainingTTL(String apiKey, String bucket); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java index b2b02389d2..30f9dbaca9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java @@ -2,58 +2,17 @@ import com.comet.opik.infrastructure.ratelimit.RateLimitService; import lombok.NonNull; -import org.redisson.api.RScript; +import org.redisson.api.RRateLimiterReactive; +import org.redisson.api.RateIntervalUnit; +import org.redisson.api.RateType; import org.redisson.api.RedissonReactiveClient; -import org.redisson.client.codec.StringCodec; import reactor.core.publisher.Mono; -import java.util.List; +import java.time.Duration; public class RedisRateLimitService implements RateLimitService { - private static final String LUA_SCRIPT_ADD = """ - local current = redis.call('GET', KEYS[1]) - if not current then - current = 0 - else - current = tonumber(current) - end - - local limit = tonumber(ARGV[1]) - local increment = tonumber(ARGV[2]) - local ttl = tonumber(ARGV[3]) - - if (current + increment) > limit then - return 0 -- Failure, limit exceeded - else - redis.call('INCRBY', KEYS[1], increment) - - if redis.call('TTL', KEYS[1]) == -1 then - redis.call('EXPIRE', KEYS[1], ttl, 'NX') - end - - return 1 -- Success, increment done - end - """; - - private static final String LUA_SCRIPT_DECR = """ - local current = redis.call('GET', KEYS[1]) - if not current then - current = 0 - else - current = tonumber(current) - end - - local decrement = tonumber(ARGV[1]) - - if (current - decrement) > 0 then - redis.call('DECRBY', KEYS[1], decrement) - else - redis.call('SET', KEYS[1], 0) - end - - return 'OK' - """; + private static final String KEY = "%s:%s"; private final RedissonReactiveClient redisClient; @@ -65,30 +24,30 @@ public RedisRateLimitService(RedissonReactiveClient redisClient) { public Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, long limitDurationInSeconds) { - Mono eval = redisClient.getScript(StringCodec.INSTANCE).eval( - RScript.Mode.READ_WRITE, - LUA_SCRIPT_ADD, - RScript.ReturnType.INTEGER, - List.of(bucketName + ":" + apiKey), - limit, - events, - limitDurationInSeconds); + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); - return eval.map(result -> result == 0); + return rateLimit.trySetRate(RateType.OVERALL, limit, limitDurationInSeconds, RateIntervalUnit.SECONDS) + .then(Mono.defer(() -> rateLimit.expireIfNotSet(Duration.ofSeconds(limitDurationInSeconds)))) + .then(Mono.defer(() -> rateLimit.tryAcquire(events))) + .map(Boolean.FALSE::equals); } @Override public Mono decrement(@NonNull String apiKey, @NonNull String bucketName, long events) { - Mono eval = redisClient.getScript(StringCodec.INSTANCE).eval( - RScript.Mode.READ_WRITE, - LUA_SCRIPT_DECR, - RScript.ReturnType.VALUE, - List.of(bucketName + ":" + apiKey), - events); + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.tryAcquire(-events).then(); + } + + @Override + public Mono availableEvents(@NonNull String apiKey, @NonNull String bucketName) { + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.availablePermits(); + } - return eval.map("OK"::equals) - .switchIfEmpty(Mono.error(new IllegalStateException("Rate limit bucket not found"))) - .then(); + @Override + public Mono getRemainingTTL(@NonNull String apiKey, @NonNull String bucketName) { + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.remainTimeToLive(); } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java index e6f90f921d..aa10a8b2bf 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java @@ -31,7 +31,6 @@ import jakarta.ws.rs.core.Response; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -48,6 +47,7 @@ import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import uk.co.jemos.podam.api.PodamFactory; +import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Set; @@ -64,6 +64,8 @@ import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; import static java.util.stream.Collectors.counting; import static java.util.stream.Collectors.groupingBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; @Testcontainers(parallel = true) @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -156,8 +158,8 @@ void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { Map responseMap = triggerCallsWithApiKey(LIMIT * 2, projectName, apiKey, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(429)); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(429)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -169,12 +171,12 @@ void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() { .get()) { // Verify that traces created are equal to the limit - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT, page.content().size()); - Assertions.assertEquals(LIMIT, page.total()); - Assertions.assertEquals(LIMIT, page.size()); + assertEquals(LIMIT, page.content().size()); + assertEquals(LIMIT, page.total()); + assertEquals(LIMIT, page.size()); } } @@ -194,13 +196,13 @@ void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllC Map responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -211,12 +213,12 @@ void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllC .header(WORKSPACE_HEADER, workspaceName) .get()) { - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT * 2, page.content().size()); - Assertions.assertEquals(LIMIT * 2, page.total()); - Assertions.assertEquals(LIMIT * 2, page.size()); + assertEquals(LIMIT * 2, page.content().size()); + assertEquals(LIMIT * 2, page.total()); + assertEquals(LIMIT * 2, page.size()); } } @@ -236,8 +238,8 @@ void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCal Map responseMap = triggerCallsWithCookie(LIMIT * 2, projectName, sessionToken, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(429)); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(429)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -248,12 +250,12 @@ void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCal .header(WORKSPACE_HEADER, workspaceName) .get()) { - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT, page.content().size()); - Assertions.assertEquals(LIMIT, page.total()); - Assertions.assertEquals(LIMIT, page.size()); + assertEquals(LIMIT, page.content().size()); + assertEquals(LIMIT, page.total()); + assertEquals(LIMIT, page.size()); } } @@ -273,13 +275,13 @@ void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAll Map responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet(); responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName); - Assertions.assertEquals(LIMIT, responseMap.get(201)); + assertEquals(LIMIT, responseMap.get(201)); try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) .queryParam("project_name", projectName) @@ -291,18 +293,18 @@ void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAll .get()) { // Verify that traces created are equal to the limit - Assertions.assertEquals(200, response.getStatus()); + assertEquals(200, response.getStatus()); TracePage page = response.readEntity(TracePage.class); - Assertions.assertEquals(LIMIT * 2, page.content().size()); - Assertions.assertEquals(LIMIT * 2, page.total()); - Assertions.assertEquals(LIMIT * 2, page.size()); + assertEquals(LIMIT * 2, page.content().size()); + assertEquals(LIMIT * 2, page.total()); + assertEquals(LIMIT * 2, page.size()); } } @Test - @DisplayName("Rate limit: When remaining limit is less than the batch size; Then reject the request") + @DisplayName("Rate limit: When remaining limit is less than the batch size, Then reject the request") void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest() { String apiKey = UUID.randomUUID().toString(); @@ -316,7 +318,7 @@ void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest( Map responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); - Assertions.assertEquals(1, responseMap.get(201)); + assertEquals(1, responseMap.get(201)); List traces = IntStream.range(0, (int) LIMIT) .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() @@ -333,14 +335,14 @@ void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest( .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(new TraceBatch(traces)))) { - Assertions.assertEquals(429, response.getStatus()); + assertEquals(429, response.getStatus()); var error = response.readEntity(ErrorMessage.class); - Assertions.assertEquals("Too Many Requests", error.getMessage()); + assertEquals("Too Many Requests", error.getMessage()); } } @Test - @DisplayName("Rate limit: When after reject request due to batch size; Then accept the request with remaining limit") + @DisplayName("Rate limit: When after reject request due to batch size, Then accept the request with remaining limit") void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRemainingLimit() { String apiKey = UUID.randomUUID().toString(); @@ -354,7 +356,7 @@ void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRe Map responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); - Assertions.assertEquals(1, responseMap.get(201)); + assertEquals(1, responseMap.get(201)); List traces = IntStream.range(0, (int) LIMIT) .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() @@ -371,9 +373,9 @@ void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRe .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(new TraceBatch(traces)))) { - Assertions.assertEquals(429, response.getStatus()); + assertEquals(429, response.getStatus()); var error = response.readEntity(ErrorMessage.class); - Assertions.assertEquals("Too Many Requests", error.getMessage()); + assertEquals("Too Many Requests", error.getMessage()); } try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) @@ -384,13 +386,13 @@ void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRe .header(WORKSPACE_HEADER, workspaceName) .post(Entity.json(new TraceBatch(traces.subList(0, (int) LIMIT - 1))))) { - Assertions.assertEquals(204, response.getStatus()); + assertEquals(204, response.getStatus()); } } @ParameterizedTest @MethodSource - @DisplayName("Rate limit: When batch endpoint consumer remaining limit; Then reject next request") + @DisplayName("Rate limit: When batch endpoint consumer remaining limit, Then reject next request") void rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest( Object batch, Object batch2, @@ -412,17 +414,114 @@ void rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest( try (var response = request.method(method, Entity.json(batch))) { - Assertions.assertEquals(204, response.getStatus()); + assertEquals(204, response.getStatus()); } try (var response = request.method(method, Entity.json(batch2))) { - Assertions.assertEquals(429, response.getStatus()); + assertEquals(429, response.getStatus()); var error = response.readEntity(ErrorMessage.class); - Assertions.assertEquals("Too Many Requests", error.getMessage()); + assertEquals("Too Many Requests", error.getMessage()); } } + @Test + @DisplayName("Rate limit: When operation fails after accepting request; Then decrement the limit") + void rateLimit__whenOperationFailsAfterAcceptingRequest__thenDecrementTheLimit() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + // consume 1 from the limit + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + assertEquals(201, response.getStatus()); + } + + // consumer limit - 2 from the limit leaving 1 remaining + Map responseMap = triggerCallsWithApiKey(LIMIT - 2, projectName, apiKey, workspaceName); + + assertEquals(LIMIT - 2, responseMap.get(201)); + + // consume the remaining limit but fail + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + assertEquals(409, response.getStatus()); + } + + // consume the remaining limit + responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); + + assertEquals(1, responseMap.get(201)); + + // verify that the limit is now 0 + responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); + + assertEquals(1, responseMap.get(429)); + } + + @Test + @DisplayName("Rate limit: When processing operations, Then return remaining limit as header") + void rateLimit__whenProcessingOperations__thenReturnRemainingLimitAsHeader() { + + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + String projectName = UUID.randomUUID().toString(); + + IntStream.range(0, (int) LIMIT + 1).forEach(i -> { + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + if (i < LIMIT) { + assertEquals(201, response.getStatus()); + + String remainingLimit = response.getHeaderString(RequestContext.USER_REMAINING_LIMIT); + String userLimit = response.getHeaderString(RequestContext.USER_LIMIT); + String remainingTtl = response.getHeaderString(RequestContext.USER_LIMIT_REMAINING_TTL); + + assertEquals(LIMIT - i - 1, Long.parseLong(remainingLimit)); + assertEquals(RateLimited.GENERAL_EVENTS, userLimit); + assertThat(Long.parseLong(remainingTtl)).isStrictlyBetween(0L, Duration.ofSeconds(LIMIT_DURATION_IN_SECONDS).toMillis()); + } else { + assertEquals(429, response.getStatus()); + } + } + }); + } + public Stream rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest() { var projectName = UUID.randomUUID().toString(); From 15ce88f4215bb96de976445f29ea3bd1088e21e8 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Fri, 20 Sep 2024 12:24:07 +0200 Subject: [PATCH 5/6] Address PR feedback --- apps/opik-backend/config.yml | 4 +- .../opik/infrastructure/RateLimitConfig.java | 8 +- .../ratelimit/RateLimitInterceptor.java | 69 +++++------ .../ratelimit/RateLimitService.java | 2 - .../redis/RedisRateLimitService.java | 12 +- .../TestDropwizardAppExtensionUtils.java | 38 +++++- .../ratelimit/RateLimitE2ETest.java | 108 +++++++++--------- .../src/test/resources/config-test.yml | 2 +- 8 files changed, 130 insertions(+), 113 deletions(-) diff --git a/apps/opik-backend/config.yml b/apps/opik-backend/config.yml index 0fce155dce..634766383c 100644 --- a/apps/opik-backend/config.yml +++ b/apps/opik-backend/config.yml @@ -69,5 +69,5 @@ server: rateLimit: enabled: ${RATE_LIMIT_ENABLED:-false} generalEvents: - limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-5000} - durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION_IN_SEC:-1} + limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-10000} + durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION_IN_SEC:-60} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java index dbaaa04ca9..11509ac4fe 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java @@ -6,6 +6,8 @@ import jakarta.validation.constraints.PositiveOrZero; import lombok.Data; +import java.util.Map; + @Data public class RateLimitConfig { @@ -19,6 +21,10 @@ public record LimitConfig(@Valid @JsonProperty @PositiveOrZero long limit, @Valid @JsonProperty - private LimitConfig generalEvents; + private LimitConfig generalLimit; + + @Valid + @JsonProperty + private Map customLimits; } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java index f7c7364f60..eeb0cba777 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java @@ -13,6 +13,9 @@ import java.lang.reflect.Method; import java.util.List; +import java.util.Optional; + +import static com.comet.opik.infrastructure.RateLimitConfig.LimitConfig; @Slf4j @RequiredArgsConstructor @@ -28,12 +31,12 @@ public Object invoke(MethodInvocation invocation) throws Throwable { // Get the method being invoked Method method = invocation.getMethod(); - // Check if the method is annotated with @RateLimit - if (!method.isAnnotationPresent(RateLimited.class)) { + if (!rateLimitConfig.isEnabled()) { return invocation.proceed(); } - if (!rateLimitConfig.isEnabled()) { + // Check if the method is annotated with @RateLimit + if (!method.isAnnotationPresent(RateLimited.class)) { return invocation.proceed(); } @@ -41,40 +44,39 @@ public Object invoke(MethodInvocation invocation) throws Throwable { String bucket = rateLimit.value(); // Check if the bucket is the general events bucket - if (bucket.equals(RateLimited.GENERAL_EVENTS)) { - - Object body = getParameters(invocation); - long events = body instanceof RateEventContainer container ? container.eventCount() : 1; + LimitConfig generalLimit = Optional.ofNullable(rateLimitConfig.getCustomLimits()) + .map(limits -> limits.get(bucket)) + .orElse(rateLimitConfig.getGeneralLimit()); - long limit = rateLimitConfig.getGeneralEvents().limit(); - long limitDurationInSeconds = rateLimitConfig.getGeneralEvents().durationInSeconds(); - String apiKey = requestContext.get().getApiKey(); + String apiKey = requestContext.get().getApiKey(); + Object body = getParameters(invocation); - // Check if the rate limit is exceeded - Boolean limitExceeded = rateLimitService.get() - .isLimitExceeded(apiKey, events, bucket, limit, limitDurationInSeconds) - .block(); + long events = body instanceof RateEventContainer container ? container.eventCount() : 1; - if (Boolean.TRUE.equals(limitExceeded)) { - setLimitHeaders(apiKey, bucket); - throw new ClientErrorException("Too Many Requests", HttpStatus.SC_TOO_MANY_REQUESTS); - } + verifyRateLimit(events, apiKey, bucket, generalLimit); - try { - return invocation.proceed(); - } catch (Exception ex) { - decreaseLimitInCaseOfError(bucket, events); - throw ex; - } finally { - setLimitHeaders(apiKey, bucket); - } + try { + return invocation.proceed(); + } finally { + setLimitHeaders(apiKey, bucket); } + } + + private void verifyRateLimit(long events, String apiKey, String bucket, LimitConfig limitConfig) { - return invocation.proceed(); + // Check if the rate limit is exceeded + Boolean limitExceeded = rateLimitService.get() + .isLimitExceeded(apiKey, events, bucket, limitConfig.limit(), limitConfig.durationInSeconds()) + .block(); + + if (Boolean.TRUE.equals(limitExceeded)) { + setLimitHeaders(apiKey, bucket); + throw new ClientErrorException("Too Many Requests", HttpStatus.SC_TOO_MANY_REQUESTS); + } } private void setLimitHeaders(String apiKey, String bucket) { - requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(RateLimited.GENERAL_EVENTS)); + requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(bucket)); requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block())); requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT, List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block())); } @@ -89,15 +91,4 @@ private Object getParameters(MethodInvocation method) { return null; } - - private void decreaseLimitInCaseOfError(String bucket, Long events) { - try { - String apiKey = requestContext.get().getApiKey(); - rateLimitService.get().decrement(apiKey, bucket, events) - .subscribe(); - } catch (Exception ex) { - log.warn("Failed to decrement rate limit", ex); - } - } - } \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java index 1b9cbb9014..68940bcae8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java @@ -7,8 +7,6 @@ public interface RateLimitService { Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, long limitDurationInSeconds); - Mono decrement(String apiKey, String bucketName, long events); - Mono availableEvents(String apiKey, String bucketName); Mono getRemainingTTL(String apiKey, String bucket); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java index 30f9dbaca9..5d78439909 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java @@ -2,6 +2,7 @@ import com.comet.opik.infrastructure.ratelimit.RateLimitService; import lombok.NonNull; +import lombok.RequiredArgsConstructor; import org.redisson.api.RRateLimiterReactive; import org.redisson.api.RateIntervalUnit; import org.redisson.api.RateType; @@ -10,16 +11,13 @@ import java.time.Duration; +@RequiredArgsConstructor public class RedisRateLimitService implements RateLimitService { private static final String KEY = "%s:%s"; private final RedissonReactiveClient redisClient; - public RedisRateLimitService(RedissonReactiveClient redisClient) { - this.redisClient = redisClient; - } - @Override public Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, long limitDurationInSeconds) { @@ -32,12 +30,6 @@ public Mono isLimitExceeded(String apiKey, long events, String bucketNa .map(Boolean.FALSE::equals); } - @Override - public Mono decrement(@NonNull String apiKey, @NonNull String bucketName, long events) { - RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); - return rateLimit.tryAcquire(-events).then(); - } - @Override public Mono availableEvents(@NonNull String apiKey, @NonNull String bucketName) { RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java index ba7b3f7539..3946e8298f 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java @@ -6,10 +6,17 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import lombok.Builder; import lombok.experimental.UtilityClass; +import org.apache.commons.collections4.CollectionUtils; import ru.vyarus.dropwizard.guice.hook.GuiceyConfigurationHook; +import ru.vyarus.dropwizard.guice.module.installer.bundle.GuiceyBundle; +import ru.vyarus.dropwizard.guice.module.installer.bundle.GuiceyEnvironment; import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static com.comet.opik.infrastructure.RateLimitConfig.LimitConfig; @UtilityClass public class TestDropwizardAppExtensionUtils { @@ -23,7 +30,10 @@ public record AppContextConfig( Integer cacheTtlInSeconds, boolean rateLimitEnabled, Long limit, - Long limitDurationInSeconds) { + Long limitDurationInSeconds, + Map customLimits, + List customBeans + ) { } public static TestDropwizardAppExtension newTestDropwizardAppExtension(String jdbcUrl, @@ -90,6 +100,19 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension(AppContex GuiceyConfigurationHook hook = injector -> { injector.modulesOverride(TestHttpClientUtils.testAuthModule()); + + injector.bundles(new GuiceyBundle() { + + @Override + public void run(GuiceyEnvironment environment) { + + if (CollectionUtils.isNotEmpty(appContextConfig.customBeans())) { + appContextConfig.customBeans() + .forEach(environment::register); + } + } + }); + }; if (appContextConfig.redisUrl() != null) { @@ -100,9 +123,18 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension(AppContex if (appContextConfig.rateLimitEnabled()) { list.add("rateLimit.enabled: true"); - list.add("rateLimit.generalEvents.limit: %d".formatted(appContextConfig.limit())); - list.add("rateLimit.generalEvents.durationInSeconds: %d" + list.add("rateLimit.generalLimit.limit: %d".formatted(appContextConfig.limit())); + list.add("rateLimit.generalLimit.durationInSeconds: %d" .formatted(appContextConfig.limitDurationInSeconds())); + + if (appContextConfig.customLimits() != null) { + appContextConfig.customLimits() + .forEach((bucket, limitConfig) -> { + list.add("rateLimit.customLimits.%s.limit: %d".formatted(bucket, limitConfig.limit())); + list.add("rateLimit.customLimits.%s.durationInSeconds: %d".formatted(bucket, + limitConfig.durationInSeconds())); + }); + } } return TestDropwizardAppExtension.forApp(OpikApplication.class) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java index f18b807b2c..5adca3a662 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java @@ -23,7 +23,12 @@ import com.redis.testcontainers.RedisContainer; import io.dropwizard.jersey.errors.ErrorMessage; import io.reactivex.rxjava3.internal.operators.single.SingleDelay; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import jakarta.ws.rs.Consumes; import jakarta.ws.rs.HttpMethod; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.client.Invocation; import jakarta.ws.rs.core.HttpHeaders; @@ -62,6 +67,7 @@ import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; import static com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils.AppContextConfig; +import static com.comet.opik.infrastructure.RateLimitConfig.LimitConfig; import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; import static java.util.stream.Collectors.counting; import static java.util.stream.Collectors.groupingBy; @@ -85,9 +91,22 @@ class RateLimitE2ETest { private static final long LIMIT = 4L; private static final long LIMIT_DURATION_IN_SECONDS = 1L; + public static final String CUSTOM_LIMIT = "customLimit"; private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + @Path("/v1/private/test") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public static class CustomRatedBean { + + @POST + @RateLimited(value = CUSTOM_LIMIT) + public Response test(@RequestBody String test) { + return Response.status(Response.Status.CREATED).build(); + } + } + static { MYSQL.start(); CLICKHOUSE.start(); @@ -107,6 +126,7 @@ class RateLimitE2ETest { .rateLimitEnabled(true) .limit(LIMIT) .limitDurationInSeconds(LIMIT_DURATION_IN_SECONDS) + .customLimits(Map.of(CUSTOM_LIMIT, new LimitConfig(1, 1))) .build()); } @@ -426,61 +446,6 @@ void rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest( } } - @Test - @DisplayName("Rate limit: When operation fails after accepting request; Then decrement the limit") - void rateLimit__whenOperationFailsAfterAcceptingRequest__thenDecrementTheLimit() { - - String apiKey = UUID.randomUUID().toString(); - String user = UUID.randomUUID().toString(); - String workspaceId = UUID.randomUUID().toString(); - String workspaceName = UUID.randomUUID().toString(); - - mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); - - String projectName = UUID.randomUUID().toString(); - - Trace trace = factory.manufacturePojo(Trace.class).toBuilder() - .projectName(projectName) - .build(); - - // consume 1 from the limit - try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) - .request() - .accept(MediaType.APPLICATION_JSON_TYPE) - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(trace))) { - - assertEquals(HttpStatus.SC_CREATED, response.getStatus()); - } - - // consumer limit - 2 from the limit leaving 1 remaining - Map responseMap = triggerCallsWithApiKey(LIMIT - 2, projectName, apiKey, workspaceName); - - assertEquals(LIMIT - 2, responseMap.get(HttpStatus.SC_CREATED)); - - // consume the remaining limit but fail - try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) - .request() - .accept(MediaType.APPLICATION_JSON_TYPE) - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(trace))) { - - assertEquals(HttpStatus.SC_CONFLICT, response.getStatus()); - } - - // consume the remaining limit - responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); - - assertEquals(1, responseMap.get(HttpStatus.SC_CREATED)); - - // verify that the limit is now 0 - responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName); - - assertEquals(1, responseMap.get(HttpStatus.SC_TOO_MANY_REQUESTS)); - } - @Test @DisplayName("Rate limit: When processing operations, Then return remaining limit as header") void rateLimit__whenProcessingOperations__thenReturnRemainingLimitAsHeader() { @@ -585,6 +550,39 @@ public Stream rateLimit__whenBatchEndpointConsumerRemainingLimit__the "%s/v1/private/experiments".formatted(baseURI) + "/items", HttpMethod.POST)); } + @Test + @DisplayName("Rate limit: When custom rated bean method is called, Then rate limit is applied") + void rateLimit__whenCustomRatedBeanMethodIsCalled__thenRateLimitIsApplied() { + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + try (var response = client.target("%s/v1/private/test".formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(""))) { + + assertEquals(HttpStatus.SC_CREATED, response.getStatus()); + } + + try (var response = client.target("%s/v1/private/test".formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(""))) { + + assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus()); + } + + } + + private Map triggerCallsWithCookie(long limit, String projectName, String sessionToken, String workspaceName) { return Flux.range(0, ((int) limit)) diff --git a/apps/opik-backend/src/test/resources/config-test.yml b/apps/opik-backend/src/test/resources/config-test.yml index b81365d436..19dddb1d7f 100644 --- a/apps/opik-backend/src/test/resources/config-test.yml +++ b/apps/opik-backend/src/test/resources/config-test.yml @@ -67,4 +67,4 @@ server: enabled: true rateLimit: - enabled: false \ No newline at end of file + enabled: false From 004a8114eba7245ec4ecd0fe60a92908d9279da4 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Fri, 20 Sep 2024 12:41:53 +0200 Subject: [PATCH 6/6] Address PR feedback --- .../ratelimit/RateLimitInterceptor.java | 21 +++++++++++------- .../ratelimit/RateLimitResponseFilter.java | 3 ++- .../TestDropwizardAppExtensionUtils.java | 3 +-- .../ratelimit/RateLimitE2ETest.java | 22 +++++++++++++------ 4 files changed, 31 insertions(+), 18 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java index eeb0cba777..aef9e154dc 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java @@ -41,11 +41,14 @@ public Object invoke(MethodInvocation invocation) throws Throwable { } RateLimited rateLimit = method.getAnnotation(RateLimited.class); - String bucket = rateLimit.value(); - // Check if the bucket is the general events bucket - LimitConfig generalLimit = Optional.ofNullable(rateLimitConfig.getCustomLimits()) - .map(limits -> limits.get(bucket)) + // Check events bucket + Optional limitConfig = Optional.ofNullable(rateLimitConfig.getCustomLimits()) + .map(limits -> limits.get(rateLimit.value())); + + String limitBucket = limitConfig.isPresent() ? rateLimit.value() : RateLimited.GENERAL_EVENTS; + + LimitConfig generalLimit = limitConfig .orElse(rateLimitConfig.getGeneralLimit()); String apiKey = requestContext.get().getApiKey(); @@ -53,12 +56,12 @@ public Object invoke(MethodInvocation invocation) throws Throwable { long events = body instanceof RateEventContainer container ? container.eventCount() : 1; - verifyRateLimit(events, apiKey, bucket, generalLimit); + verifyRateLimit(events, apiKey, limitBucket, generalLimit); try { return invocation.proceed(); } finally { - setLimitHeaders(apiKey, bucket); + setLimitHeaders(apiKey, limitBucket); } } @@ -77,8 +80,10 @@ private void verifyRateLimit(long events, String apiKey, String bucket, LimitCon private void setLimitHeaders(String apiKey, String bucket) { requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(bucket)); - requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block())); - requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT, List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block())); + requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, + List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block())); + requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT, + List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block())); } private Object getParameters(MethodInvocation method) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java index 186a5bb432..2f1e4cc264 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java @@ -16,7 +16,8 @@ public class RateLimitResponseFilter implements ContainerResponseFilter { @Override - public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) throws IOException { + public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) + throws IOException { List userLimit = getValueFromHeader(requestContext, RequestContext.USER_LIMIT); List remainingLimit = getValueFromHeader(requestContext, RequestContext.USER_REMAINING_LIMIT); List remainingTtl = getValueFromHeader(requestContext, RequestContext.USER_LIMIT_REMAINING_TTL); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java index 3946e8298f..10207eb88e 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java @@ -32,8 +32,7 @@ public record AppContextConfig( Long limit, Long limitDurationInSeconds, Map customLimits, - List customBeans - ) { + List customBeans) { } public static TestDropwizardAppExtension newTestDropwizardAppExtension(String jdbcUrl, diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java index 5adca3a662..b889d7b2c8 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java @@ -474,13 +474,8 @@ void rateLimit__whenProcessingOperations__thenReturnRemainingLimitAsHeader() { if (i < LIMIT) { assertEquals(HttpStatus.SC_CREATED, response.getStatus()); - String remainingLimit = response.getHeaderString(RequestContext.USER_REMAINING_LIMIT); - String userLimit = response.getHeaderString(RequestContext.USER_LIMIT); - String remainingTtl = response.getHeaderString(RequestContext.USER_LIMIT_REMAINING_TTL); - - assertEquals(LIMIT - i - 1, Long.parseLong(remainingLimit)); - assertEquals(RateLimited.GENERAL_EVENTS, userLimit); - assertThat(Long.parseLong(remainingTtl)).isStrictlyBetween(0L, Duration.ofSeconds(LIMIT_DURATION_IN_SECONDS).toMillis()); + assertLimitHeaders(response, LIMIT - i - 1, RateLimited.GENERAL_EVENTS, + (int) LIMIT_DURATION_IN_SECONDS); } else { assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus()); } @@ -568,6 +563,8 @@ void rateLimit__whenCustomRatedBeanMethodIsCalled__thenRateLimitIsApplied() { .post(Entity.json(""))) { assertEquals(HttpStatus.SC_CREATED, response.getStatus()); + + assertLimitHeaders(response, 0, CUSTOM_LIMIT, 1); } try (var response = client.target("%s/v1/private/test".formatted(baseURI)) @@ -578,10 +575,21 @@ void rateLimit__whenCustomRatedBeanMethodIsCalled__thenRateLimitIsApplied() { .post(Entity.json(""))) { assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus()); + + assertLimitHeaders(response, 0, CUSTOM_LIMIT, 1); } } + private static void assertLimitHeaders(Response response, long expected, String limitBucket, int limitDuration) { + String remainingLimit = response.getHeaderString(RequestContext.USER_REMAINING_LIMIT); + String userLimit = response.getHeaderString(RequestContext.USER_LIMIT); + String remainingTtl = response.getHeaderString(RequestContext.USER_LIMIT_REMAINING_TTL); + + assertEquals(expected, Long.parseLong(remainingLimit)); + assertEquals(limitBucket, userLimit); + assertThat(Long.parseLong(remainingTtl)).isBetween(0L, Duration.ofSeconds(limitDuration).toMillis()); + } private Map triggerCallsWithCookie(long limit, String projectName, String sessionToken, String workspaceName) {