diff --git a/.travis.yml b/.travis.yml index a40bdf55e..4722957c8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,7 +23,7 @@ matrix: - jdk: openjdk8 - jdk: openjdk11 env: SKIP_RELEASE=true - - jdk: openjdk13 + - jdk: openjdk14 env: SKIP_RELEASE=true env: diff --git a/README.md b/README.md index 173c3e1ad..538587154 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,26 @@ Releases are available via Maven Central. Example: ```groovy +repositories { + mavenCentral() +} +dependencies { + implementation 'io.rsocket:rsocket-core:1.0.0-RC6' + implementation 'io.rsocket:rsocket-transport-netty:1.0.0-RC6' +} +``` + +Snapshots are available via [oss.jfrog.org](oss.jfrog.org) (OJO). + +Example: + +```groovy +repositories { + maven { url 'https://oss.jfrog.org/oss-snapshot-local' } +} dependencies { - implementation 'io.rsocket:rsocket-core:1.0.0-RC3' - implementation 'io.rsocket:rsocket-transport-netty:1.0.0-RC3' -// implementation 'io.rsocket:rsocket-core:1.0.0-RC4-SNAPSHOT' -// implementation 'io.rsocket:rsocket-transport-netty:1.0.0-RC4-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.0.0-RC7-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.0.0-RC7-SNAPSHOT' } ``` @@ -57,7 +72,7 @@ package io.rsocket.transport.netty; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.util.DefaultPayload; import reactor.core.publisher.Flux; @@ -67,14 +82,14 @@ import java.net.URI; public class ExampleClient { public static void main(String[] args) { WebsocketClientTransport ws = WebsocketClientTransport.create(URI.create("ws://rsocket-demo.herokuapp.com/ws")); - RSocket client = RSocketFactory.connect().keepAlive().transport(ws).start().block(); + RSocket clientRSocket = RSocketConnector.connectWith(ws).block(); try { - Flux s = client.requestStream(DefaultPayload.create("peace")); + Flux s = clientRSocket.requestStream(DefaultPayload.create("peace")); s.take(10).doOnNext(p -> System.out.println(p.getDataUtf8())).blockLast(); } finally { - client.dispose(); + clientRSocket.dispose(); } } } @@ -89,12 +104,10 @@ or you will get a memory leak. Used correctly this will reduce latency and incre ### Example Server setup ```java -RSocketFactory.receive() +RSocketServer.create(new PingHandler()) // Enable Zero Copy - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(TcpServerTransport.create(7878)) - .start() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(7878)) .block() .onClose() .block(); @@ -102,12 +115,13 @@ RSocketFactory.receive() ### Example Client setup ```java -Mono client = - RSocketFactory.connect() +RSocket clientRSocket = + RSocketConnector.create() // Enable Zero Copy - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(TcpClientTransport.create(7878)) - .start(); + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(TcpClientTransport.create(7878)) + .start() + .block(); ``` ## Bugs and Feedback diff --git a/build.gradle b/build.gradle index e834e21bd..2c7757e0f 100644 --- a/build.gradle +++ b/build.gradle @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,12 +26,12 @@ plugins { subprojects { apply plugin: 'io.spring.dependency-management' apply plugin: 'com.github.sherter.google-java-format' - - ext['reactor-bom.version'] = 'Dysprosium-RELEASE' + + ext['reactor-bom.version'] = 'Dysprosium-SR7' ext['logback.version'] = '1.2.3' ext['findbugs.version'] = '3.0.2' - ext['netty-bom.version'] = '4.1.37.Final' - ext['netty-boringssl.version'] = '2.0.25.Final' + ext['netty-bom.version'] = '4.1.48.Final' + ext['netty-boringssl.version'] = '2.0.30.Final' ext['hdrhistogram.version'] = '2.1.10' ext['mockito.version'] = '3.2.0' ext['slf4j.version'] = '1.7.25' @@ -88,11 +88,18 @@ subprojects { repositories { mavenCentral() - if (version.endsWith('BUILD-SNAPSHOT') || project.hasProperty('platformVersion')) { + if (version.endsWith('SNAPSHOT') || project.hasProperty('platformVersion')) { maven { url 'http://repo.spring.io/libs-snapshot' } + maven { + url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' + } } } + tasks.withType(GenerateModuleMetadata) { + enabled = false + } + plugins.withType(JavaPlugin) { compileJava { sourceCompatibility = 1.8 @@ -102,21 +109,61 @@ subprojects { } javadoc { + def jdk = JavaVersion.current().majorVersion + def jdkJavadoc = "https://docs.oracle.com/javase/$jdk/docs/api/" + if (JavaVersion.current().isJava11Compatible()) { + jdkJavadoc = "https://docs.oracle.com/en/java/javase/$jdk/docs/api/" + } options.with { - links 'https://docs.oracle.com/javase/8/docs/api/' + links jdkJavadoc links 'https://projectreactor.io/docs/core/release/api/' links 'https://netty.io/4.1/api/' } } + tasks.named("javadoc").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + test { useJUnitPlatform() systemProperty "io.netty.leakDetection.level", "ADVANCED" } - tasks.named("javadoc").configure { - onlyIf { System.getenv('SKIP_RELEASE') != "true" } + //all test tasks will show FAILED for each test method, + // common exclusions, no scanning + project.tasks.withType(Test).all { + testLogging { + events "FAILED" + showExceptions true + exceptionFormat "FULL" + stackTraceFilters "ENTRY_POINT" + maxGranularity 3 + } + + if (JavaVersion.current().isJava9Compatible()) { + println "Java 9+: lowering MaxGCPauseMillis to 20ms in ${project.name} ${name}" + jvmArgs = ["-XX:MaxGCPauseMillis=20"] + } + + systemProperty("java.awt.headless", "true") + systemProperty("reactor.trace.cancel", "true") + systemProperty("reactor.trace.nocapacity", "true") + systemProperty("testGroups", project.properties.get("testGroups")) + scanForTestClasses = false + exclude '**/*Abstract*.*' + + //allow re-run of failed tests only without special test tasks failing + // because the filter is too restrictive + filter.setFailOnNoMatchingTests(false) + + //display intermediate results for special test tasks + afterSuite { desc, result -> + if (!desc.parent) { // will match the outermost suite + println('\n' + "${desc} Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} successes, ${result.failedTestCount} failures, ${result.skippedTestCount} skipped)") + } + } } } diff --git a/gradle.properties b/gradle.properties index 13a89e30c..3018f4792 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,5 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -version=1.0.0-RC6 -perfBaselineVersion=1.0.0-RC5 +version=1.0.0-RC7 +perfBaselineVersion=1.0.0-RC6 diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 94920145f..a4b442974 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.0.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.3-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/rsocket-core/src/main/java/io/rsocket/Closeable.java b/rsocket-core/src/main/java/io/rsocket/Closeable.java index 5eb871e18..2ea9a0371 100644 --- a/rsocket-core/src/main/java/io/rsocket/Closeable.java +++ b/rsocket-core/src/main/java/io/rsocket/Closeable.java @@ -16,17 +16,21 @@ package io.rsocket; +import org.reactivestreams.Subscriber; import reactor.core.Disposable; import reactor.core.publisher.Mono; -/** */ +/** An interface which allows listening to when a specific instance of this interface is closed */ public interface Closeable extends Disposable { /** - * Returns a {@code Publisher} that completes when this {@code RSocket} is closed. A {@code - * RSocket} can be closed by explicitly calling {@link RSocket#dispose()} or when the underlying - * transport connection is closed. + * Returns a {@link Mono} that terminates when the instance is terminated by any reason. Note, in + * case of error termination, the cause of error will be propagated as an error signal through + * {@link org.reactivestreams.Subscriber#onError(Throwable)}. Otherwise, {@link + * Subscriber#onComplete()} will be called. * - * @return A {@code Publisher} that completes when this {@code RSocket} close is complete. + * @return a {@link Mono} to track completion with success or error of the underlying resource. + * When the underlying resource is an `RSocket`, the {@code Mono} exposes stream 0 (i.e. + * connection level) errors. */ Mono onClose(); } diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index 8762e0489..bd4582e2b 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -1,11 +1,11 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,28 +18,22 @@ import io.netty.buffer.ByteBuf; import io.netty.util.AbstractReferenceCounted; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.core.DefaultConnectionSetupPayload; import javax.annotation.Nullable; /** - * Exposed to server for determination of ResponderRSocket based on mime types and SETUP - * metadata/data + * Exposes information from the {@code SETUP} frame to a server, as well as to client responders. */ public abstract class ConnectionSetupPayload extends AbstractReferenceCounted implements Payload { - public static ConnectionSetupPayload create(final ByteBuf setupFrame) { - return new DefaultConnectionSetupPayload(setupFrame); - } + public abstract String metadataMimeType(); + + public abstract String dataMimeType(); public abstract int keepAliveInterval(); public abstract int keepAliveMaxLifetime(); - public abstract String metadataMimeType(); - - public abstract String dataMimeType(); - public abstract int getFlags(); public abstract boolean willClientHonorLease(); @@ -64,96 +58,15 @@ public ConnectionSetupPayload retain(int increment) { @Override public abstract ConnectionSetupPayload touch(); - @Override - public abstract ConnectionSetupPayload touch(Object hint); - - private static final class DefaultConnectionSetupPayload extends ConnectionSetupPayload { - private final ByteBuf setupFrame; - - public DefaultConnectionSetupPayload(ByteBuf setupFrame) { - this.setupFrame = setupFrame; - } - - @Override - public boolean hasMetadata() { - return FrameHeaderFlyweight.hasMetadata(setupFrame); - } - - @Override - public int keepAliveInterval() { - return SetupFrameFlyweight.keepAliveInterval(setupFrame); - } - - @Override - public int keepAliveMaxLifetime() { - return SetupFrameFlyweight.keepAliveMaxLifetime(setupFrame); - } - - @Override - public String metadataMimeType() { - return SetupFrameFlyweight.metadataMimeType(setupFrame); - } - - @Override - public String dataMimeType() { - return SetupFrameFlyweight.dataMimeType(setupFrame); - } - - @Override - public int getFlags() { - return FrameHeaderFlyweight.flags(setupFrame); - } - - @Override - public boolean willClientHonorLease() { - return SetupFrameFlyweight.honorLease(setupFrame); - } - - @Override - public boolean isResumeEnabled() { - return SetupFrameFlyweight.resumeEnabled(setupFrame); - } - - @Override - public ByteBuf resumeToken() { - return SetupFrameFlyweight.resumeToken(setupFrame); - } - - @Override - public ConnectionSetupPayload touch() { - setupFrame.touch(); - return this; - } - - @Override - public ConnectionSetupPayload touch(Object hint) { - setupFrame.touch(hint); - return this; - } - - @Override - protected void deallocate() { - setupFrame.release(); - } - - @Override - public ByteBuf sliceMetadata() { - return SetupFrameFlyweight.metadata(setupFrame); - } - - @Override - public ByteBuf sliceData() { - return SetupFrameFlyweight.data(setupFrame); - } - - @Override - public ByteBuf data() { - return sliceData(); - } - - @Override - public ByteBuf metadata() { - return sliceMetadata(); - } + /** + * Create a {@code ConnectionSetupPayload}. + * + * @deprecated as of 1.0 RC7. Please, use {@link + * DefaultConnectionSetupPayload#DefaultConnectionSetupPayload(ByteBuf) + * DefaultConnectionSetupPayload} constructor. + */ + @Deprecated + public static ConnectionSetupPayload create(final ByteBuf setupFrame) { + return new DefaultConnectionSetupPayload(setupFrame); } } diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java index b87ed0570..6190d24e3 100644 --- a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import java.nio.channels.ClosedChannelException; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -78,6 +79,13 @@ default Mono sendOne(ByteBuf frame) { */ Flux receive(); + /** + * Returns the assigned {@link ByteBufAllocator}. + * + * @return the {@link ByteBufAllocator} + */ + ByteBufAllocator alloc(); + @Override default double availability() { return isDisposed() ? 0.0 : 1.0; diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java new file mode 100644 index 000000000..b43b14bae --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java @@ -0,0 +1,82 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import reactor.util.annotation.Nullable; + +/** + * Exception that represents an RSocket protocol error. + * + * @see ERROR + * Frame (0x0B) + */ +public class RSocketErrorException extends RuntimeException { + + private static final long serialVersionUID = -1628781753426267554L; + + private static final int MIN_ERROR_CODE = 0x00000001; + + private static final int MAX_ERROR_CODE = 0xFFFFFFFE; + + private final int errorCode; + + /** + * Constructor with a protocol error code and a message. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + */ + public RSocketErrorException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Alternative to {@link #RSocketErrorException(int, String)} with a root cause. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + * @param cause a root cause for the error + */ + public RSocketErrorException(int errorCode, String message, @Nullable Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + if (errorCode > MAX_ERROR_CODE && errorCode < MIN_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000001-0xFFFFFFFE]", this); + } + } + + /** + * Return the RSocket error code + * represented by this exception + * + * @return the RSocket protocol error code + */ + public int errorCode() { + return errorCode; + } + + @Override + public String toString() { + return getClass().getSimpleName() + + " (0x" + + Integer.toHexString(errorCode) + + "): " + + getMessage(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 44f64e550..178cc4fa9 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,57 +13,62 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.rsocket; -import static io.rsocket.internal.ClientSetup.DefaultClientSetup; -import static io.rsocket.internal.ClientSetup.ResumableClientSetup; - import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.InvalidSetupException; -import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.ResumeFrameFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; +import io.netty.buffer.Unpooled; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; -import io.rsocket.internal.ClientSetup; -import io.rsocket.internal.ServerSetup; -import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.lease.LeaseStats; import io.rsocket.lease.Leases; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; -import io.rsocket.plugins.*; -import io.rsocket.resume.*; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.plugins.SocketAcceptorInterceptor; +import io.rsocket.resume.ClientResume; +import io.rsocket.resume.ResumableFramesStore; +import io.rsocket.resume.ResumeStrategy; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.util.ConnectionUtils; -import io.rsocket.util.EmptyPayload; -import io.rsocket.util.MultiSubscriberRSocket; import java.time.Duration; -import java.util.Objects; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * Main entry point to create RSocket clients or servers as follows: + * + *
    + *
  • {@link ClientRSocketFactory} to connect as a client. Use {@link #connect()} for a default + * instance. + *
  • {@link ServerRSocketFactory} to start a server. Use {@link #receive()} for a default + * instance. + *
+ * + * @deprecated please use {@link RSocketConnector} and {@link RSocketServer}. + */ +@Deprecated +public final class RSocketFactory { -/** Factory for creating RSocket clients and servers. */ -public class RSocketFactory { /** - * Creates a factory that establishes client connections to other RSockets. + * Create a {@code ClientRSocketFactory} to connect to a remote RSocket endpoint. Internally + * delegates to {@link RSocketConnector}. * - * @return a client factory + * @return the {@code ClientRSocketFactory} instance */ public static ClientRSocketFactory connect() { return new ClientRSocketFactory(); } /** - * Creates a factory that receives server connections from client RSockets. + * Create a {@code ServerRSocketFactory} to accept connections from RSocket clients. Internally + * delegates to {@link RSocketServer}. * - * @return a server factory. + * @return the {@code ClientRSocketFactory} instance */ public static ServerRSocketFactory receive() { return new ServerRSocketFactory(); @@ -92,52 +97,58 @@ default Start transport(ServerTransport transport) { } } + /** Factory to create and configure an RSocket client, and connect to a server. */ public static class ClientRSocketFactory implements ClientTransportAcceptor { - private static final String CLIENT_TAG = "client"; - - private SocketAcceptor acceptor = (setup, sendingSocket) -> Mono.just(new AbstractRSocket() {}); - - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); + private static final ClientResume CLIENT_RESUME = + new ClientResume(Duration.ofMinutes(2), Unpooled.EMPTY_BUFFER); - private Payload setupPayload = EmptyPayload.INSTANCE; - private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + private final RSocketConnector connector; private Duration tickPeriod = Duration.ofSeconds(20); private Duration ackTimeout = Duration.ofSeconds(30); private int missedAcks = 3; - private String metadataMimeType = "application/binary"; - private String dataMimeType = "application/binary"; + private Resume resume; - private boolean resumeEnabled; - private boolean resumeCleanupStoreOnKeepAlive; - private Supplier resumeTokenSupplier = ResumeFrameFlyweight::generateResumeToken; - private Function resumeStoreFactory = - token -> new InMemoryResumableFramesStore(CLIENT_TAG, 100_000); - private Duration resumeSessionDuration = Duration.ofMinutes(2); - private Duration resumeStreamTimeout = Duration.ofSeconds(10); - private Supplier resumeStrategySupplier = - () -> - new ExponentialBackoffResumeStrategy(Duration.ofSeconds(1), Duration.ofSeconds(16), 2); - - private boolean multiSubscriberRequester = true; - private boolean leaseEnabled; - private Supplier> leasesSupplier = Leases::new; + public ClientRSocketFactory() { + this(RSocketConnector.create().errorConsumer(Throwable::printStackTrace)); + } - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + public ClientRSocketFactory(RSocketConnector connector) { + this.connector = connector; + } + /** + * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in + * order configure the custom {@link ByteBufAllocator} it is recommended to use the + * following setup for Reactor Netty based transport:
+ * 1. For Client:
+ *
{@code
+     * TcpClient.create()
+     *          ...
+     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
+     * }
+ *
+ * 2. For server:
+ *
{@code
+     * TcpServer.create()
+     *          ...
+     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
+     * }
+ * Or in case of local transport, to use corresponding factory method {@code + * LocalClientTransport.creat(String, ByteBufAllocator)} + * @param allocator instance of {@link ByteBufAllocator} + * @return this factory instance + */ public ClientRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - Objects.requireNonNull(allocator); - this.allocator = allocator; return this; } public ClientRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); + connector.interceptors(registry -> registry.forConnection(interceptor)); return this; } + /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ @Deprecated public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { @@ -145,7 +156,7 @@ public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { } public ClientRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { - plugins.addRequesterPlugin(interceptor); + connector.interceptors(registry -> registry.forRequester(interceptor)); return this; } @@ -156,309 +167,292 @@ public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { } public ClientRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { - plugins.addResponderPlugin(interceptor); + connector.interceptors(registry -> registry.forResponder(interceptor)); return this; } public ClientRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - plugins.addSocketAcceptorPlugin(interceptor); + connector.interceptors(registry -> registry.forSocketAcceptor(interceptor)); return this; } /** - * Deprecated as Keep-Alive is not optional according to spec + * Deprecated without replacement as Keep-Alive is not optional according to spec * * @return this ClientRSocketFactory */ @Deprecated public ClientRSocketFactory keepAlive() { + connector.keepAlive(tickPeriod, ackTimeout.plus(tickPeriod.multipliedBy(missedAcks))); return this; } - public ClientRSocketFactory keepAlive( + public ClientTransportAcceptor keepAlive( Duration tickPeriod, Duration ackTimeout, int missedAcks) { this.tickPeriod = tickPeriod; this.ackTimeout = ackTimeout; this.missedAcks = missedAcks; + keepAlive(); return this; } public ClientRSocketFactory keepAliveTickPeriod(Duration tickPeriod) { this.tickPeriod = tickPeriod; + keepAlive(); return this; } public ClientRSocketFactory keepAliveAckTimeout(Duration ackTimeout) { this.ackTimeout = ackTimeout; + keepAlive(); return this; } public ClientRSocketFactory keepAliveMissedAcks(int missedAcks) { this.missedAcks = missedAcks; + keepAlive(); return this; } public ClientRSocketFactory mimeType(String metadataMimeType, String dataMimeType) { - this.dataMimeType = dataMimeType; - this.metadataMimeType = metadataMimeType; + connector.metadataMimeType(metadataMimeType); + connector.dataMimeType(dataMimeType); return this; } public ClientRSocketFactory dataMimeType(String dataMimeType) { - this.dataMimeType = dataMimeType; + connector.dataMimeType(dataMimeType); return this; } public ClientRSocketFactory metadataMimeType(String metadataMimeType) { - this.metadataMimeType = metadataMimeType; + connector.metadataMimeType(metadataMimeType); return this; } - public ClientRSocketFactory lease(Supplier> leasesSupplier) { - this.leaseEnabled = true; - this.leasesSupplier = Objects.requireNonNull(leasesSupplier); + public ClientRSocketFactory lease(Supplier> supplier) { + connector.lease(supplier); return this; } public ClientRSocketFactory lease() { - this.leaseEnabled = true; + connector.lease(Leases::new); return this; } + /** @deprecated without a replacement and no longer used. */ + @Deprecated public ClientRSocketFactory singleSubscriberRequester() { - this.multiSubscriberRequester = false; + return this; + } + + /** + * Enables a reconnectable, shared instance of {@code Mono} so every subscriber will + * observe the same RSocket instance up on connection establishment.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  RSocket r1 = sharedRSocketMono.block();
+     *  RSocket r2 = sharedRSocketMono.block();
+     *
+     *  assert r1 == r2;
+     *
+     * }
+ * + * Apart of the shared behavior, if the connection is lost, the same {@code Mono} + * instance will transparently re-establish the connection for subsequent subscribers.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  RSocket r1 = sharedRSocketMono.block();
+     *  RSocket r2 = sharedRSocketMono.block();
+     *
+     *  assert r1 == r2;
+     *
+     *  r1.dispose()
+     *
+     *  assert r2.isDisposed()
+     *
+     *  RSocket r3 = sharedRSocketMono.block();
+     *  RSocket r4 = sharedRSocketMono.block();
+     *
+     *
+     *  assert r1 != r3;
+     *  assert r4 == r3;
+     *
+     * }
+ * + * Note, having reconnect() enabled does not eliminate the need to accompany each + * individual request with the corresponding retry logic.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  sharedRSocket.flatMap(rSocket -> rSocket.requestResponse(...))
+     *               .retryWhen(ownRetry)
+     *               .subscribe()
+     *
+     * }
+ * + * @param retrySpec a retry factory applied for {@link Mono#retryWhen(Retry)} + * @return a shared instance of {@code Mono}. + */ + public ClientRSocketFactory reconnect(Retry retrySpec) { + connector.reconnect(retrySpec); return this; } public ClientRSocketFactory resume() { - this.resumeEnabled = true; + resume = resume != null ? resume : new Resume(); + connector.resume(resume); return this; } - public ClientRSocketFactory resumeToken(Supplier resumeTokenSupplier) { - this.resumeTokenSupplier = Objects.requireNonNull(resumeTokenSupplier); + public ClientRSocketFactory resumeToken(Supplier supplier) { + resume(); + resume.token(supplier); return this; } public ClientRSocketFactory resumeStore( - Function resumeStoreFactory) { - this.resumeStoreFactory = resumeStoreFactory; + Function storeFactory) { + resume(); + resume.storeFactory(storeFactory); return this; } public ClientRSocketFactory resumeSessionDuration(Duration sessionDuration) { - this.resumeSessionDuration = Objects.requireNonNull(sessionDuration); + resume(); + resume.sessionDuration(sessionDuration); return this; } - public ClientRSocketFactory resumeStreamTimeout(Duration resumeStreamTimeout) { - this.resumeStreamTimeout = Objects.requireNonNull(resumeStreamTimeout); + public ClientRSocketFactory resumeStreamTimeout(Duration streamTimeout) { + resume(); + resume.streamTimeout(streamTimeout); return this; } - public ClientRSocketFactory resumeStrategy(Supplier resumeStrategy) { - this.resumeStrategySupplier = Objects.requireNonNull(resumeStrategy); + public ClientRSocketFactory resumeStrategy(Supplier strategy) { + resume(); + resume.retry( + Retry.from( + signals -> signals.flatMap(s -> strategy.get().apply(CLIENT_RESUME, s.failure())))); return this; } public ClientRSocketFactory resumeCleanupOnKeepAlive() { - resumeCleanupStoreOnKeepAlive = true; + resume(); + resume.cleanupStoreOnKeepAlive(); return this; } - @Override - public Start transport(Supplier transportClient) { - return new StartClient(transportClient); + public Start transport(Supplier transport) { + return () -> connector.connect(transport); } public ClientTransportAcceptor acceptor(Function acceptor) { return acceptor(() -> acceptor); } - public ClientTransportAcceptor acceptor(Supplier> acceptor) { - return acceptor((setup, sendingSocket) -> Mono.just(acceptor.get().apply(sendingSocket))); + public ClientTransportAcceptor acceptor(Supplier> acceptorSupplier) { + return acceptor( + (setup, sendingSocket) -> { + acceptorSupplier.get().apply(sendingSocket); + return Mono.empty(); + }); } public ClientTransportAcceptor acceptor(SocketAcceptor acceptor) { - this.acceptor = acceptor; - return StartClient::new; + connector.acceptor(acceptor); + return this; } public ClientRSocketFactory fragment(int mtu) { - this.mtu = mtu; + connector.fragment(mtu); return this; } + /** @deprecated this is deprecated with no replacement. */ public ClientRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; + connector.errorConsumer(errorConsumer); return this; } public ClientRSocketFactory setupPayload(Payload payload) { - this.setupPayload = payload; + connector.setupPayload(payload); return this; } public ClientRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - this.payloadDecoder = payloadDecoder; - return this; - } - - private class StartClient implements Start { - private final Supplier transportClient; - - StartClient(Supplier transportClient) { - this.transportClient = transportClient; - } - - @Override - public Mono start() { - return newConnection() - .flatMap( - connection -> { - ClientSetup clientSetup = clientSetup(connection); - ByteBuf resumeToken = clientSetup.resumeToken(); - KeepAliveHandler keepAliveHandler = clientSetup.keepAliveHandler(); - DuplexConnection wrappedConnection = clientSetup.connection(); - - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(wrappedConnection, plugins, true); - - boolean isLeaseEnabled = leaseEnabled; - Leases leases = leasesSupplier.get(); - RequesterLeaseHandler requesterLeaseHandler = - isLeaseEnabled - ? new RequesterLeaseHandler.Impl(CLIENT_TAG, leases.receiver()) - : RequesterLeaseHandler.None; - - RSocket rSocketRequester = - new RSocketRequester( - allocator, - multiplexer.asClientConnection(), - payloadDecoder, - errorConsumer, - StreamIdSupplier.clientSupplier(), - keepAliveTickPeriod(), - keepAliveTimeout(), - keepAliveHandler, - requesterLeaseHandler); - - if (multiSubscriberRequester) { - rSocketRequester = new MultiSubscriberRSocket(rSocketRequester); - } - - RSocket wrappedRSocketRequester = plugins.applyRequester(rSocketRequester); - - ByteBuf setupFrame = - SetupFrameFlyweight.encode( - allocator, - isLeaseEnabled, - keepAliveTickPeriod(), - keepAliveTimeout(), - resumeToken, - metadataMimeType, - dataMimeType, - setupPayload); - - ConnectionSetupPayload setup = ConnectionSetupPayload.create(setupFrame); - - return plugins - .applySocketAcceptorInterceptor(acceptor) - .accept(setup, wrappedRSocketRequester) - .flatMap( - rSocketHandler -> { - RSocket wrappedRSocketHandler = plugins.applyResponder(rSocketHandler); - - ResponderLeaseHandler responderLeaseHandler = - isLeaseEnabled - ? new ResponderLeaseHandler.Impl<>( - CLIENT_TAG, - allocator, - leases.sender(), - errorConsumer, - leases.stats()) - : ResponderLeaseHandler.None; - - RSocket rSocketResponder = - new RSocketResponder( - allocator, - multiplexer.asServerConnection(), - wrappedRSocketHandler, - payloadDecoder, - errorConsumer, - responderLeaseHandler); - - return wrappedConnection - .sendOne(setupFrame) - .thenReturn(wrappedRSocketRequester); - }); - }); - } - - private int keepAliveTickPeriod() { - return (int) tickPeriod.toMillis(); - } - - private int keepAliveTimeout() { - return (int) (ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks); - } - - private ClientSetup clientSetup(DuplexConnection startConnection) { - if (resumeEnabled) { - ByteBuf resumeToken = resumeTokenSupplier.get(); - return new ResumableClientSetup( - allocator, - startConnection, - newConnection(), - resumeToken, - resumeStoreFactory.apply(resumeToken), - resumeSessionDuration, - resumeStreamTimeout, - resumeStrategySupplier, - resumeCleanupStoreOnKeepAlive); - } else { - return new DefaultClientSetup(startConnection); - } - } - - private Mono newConnection() { - return transportClient.get().connect(mtu); - } + connector.payloadDecoder(payloadDecoder); + return this; } } - public static class ServerRSocketFactory { - private static final String SERVER_TAG = "server"; - - private SocketAcceptor acceptor; - private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); - - private boolean resumeSupported; - private Duration resumeSessionDuration = Duration.ofSeconds(120); - private Duration resumeStreamTimeout = Duration.ofSeconds(10); - private Function resumeStoreFactory = - token -> new InMemoryResumableFramesStore(SERVER_TAG, 100_000); + /** Factory to create, configure, and start an RSocket server. */ + public static class ServerRSocketFactory implements ServerTransportAcceptor { + private final RSocketServer server; - private boolean multiSubscriberRequester = true; - private boolean leaseEnabled; - private Supplier> leasesSupplier = Leases::new; + private Resume resume; - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - private boolean resumeCleanupStoreOnKeepAlive; + public ServerRSocketFactory() { + this(RSocketServer.create().errorConsumer(Throwable::printStackTrace)); + } - private ServerRSocketFactory() {} + public ServerRSocketFactory(RSocketServer server) { + this.server = server; + } + /** + * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in + * order configure the custom {@link ByteBufAllocator} it is recommended to use the + * following setup for Reactor Netty based transport:
+ * 1. For Client:
+ *
{@code
+     * TcpClient.create()
+     *          ...
+     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
+     * }
+ *
+ * 2. For server:
+ *
{@code
+     * TcpServer.create()
+     *          ...
+     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
+     * }
+ * Or in case of local transport, to use corresponding factory method {@code + * LocalClientTransport.creat(String, ByteBufAllocator)} + * @param allocator instance of {@link ByteBufAllocator} + * @return this factory instance + */ + @Deprecated public ServerRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - Objects.requireNonNull(allocator); - this.allocator = allocator; return this; } public ServerRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); + server.interceptors(registry -> registry.forConnection(interceptor)); return this; } /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ @@ -468,7 +462,7 @@ public ServerRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { } public ServerRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { - plugins.addRequesterPlugin(interceptor); + server.interceptors(registry -> registry.forRequester(interceptor)); return this; } @@ -479,265 +473,91 @@ public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { } public ServerRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { - plugins.addResponderPlugin(interceptor); + server.interceptors(registry -> registry.forResponder(interceptor)); return this; } public ServerRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - plugins.addSocketAcceptorPlugin(interceptor); + server.interceptors(registry -> registry.forSocketAcceptor(interceptor)); return this; } public ServerTransportAcceptor acceptor(SocketAcceptor acceptor) { - this.acceptor = acceptor; - return new ServerStart<>(); + server.acceptor(acceptor); + return this; } public ServerRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - this.payloadDecoder = payloadDecoder; + server.payloadDecoder(payloadDecoder); return this; } public ServerRSocketFactory fragment(int mtu) { - this.mtu = mtu; + server.fragment(mtu); return this; } + /** @deprecated this is deprecated with no replacement. */ public ServerRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; + server.errorConsumer(errorConsumer); return this; } - public ServerRSocketFactory lease(Supplier> leasesSupplier) { - this.leaseEnabled = true; - this.leasesSupplier = Objects.requireNonNull(leasesSupplier); + public ServerRSocketFactory lease(Supplier> supplier) { + server.lease(supplier); return this; } public ServerRSocketFactory lease() { - this.leaseEnabled = true; + server.lease(Leases::new); return this; } + /** @deprecated without a replacement and no longer used. */ + @Deprecated public ServerRSocketFactory singleSubscriberRequester() { - this.multiSubscriberRequester = false; return this; } public ServerRSocketFactory resume() { - this.resumeSupported = true; + resume = resume != null ? resume : new Resume(); + server.resume(resume); return this; } public ServerRSocketFactory resumeStore( - Function resumeStoreFactory) { - this.resumeStoreFactory = resumeStoreFactory; + Function storeFactory) { + resume(); + resume.storeFactory(storeFactory); return this; } public ServerRSocketFactory resumeSessionDuration(Duration sessionDuration) { - this.resumeSessionDuration = Objects.requireNonNull(sessionDuration); + resume(); + resume.sessionDuration(sessionDuration); return this; } - public ServerRSocketFactory resumeStreamTimeout(Duration resumeStreamTimeout) { - this.resumeStreamTimeout = Objects.requireNonNull(resumeStreamTimeout); + public ServerRSocketFactory resumeStreamTimeout(Duration streamTimeout) { + resume(); + resume.streamTimeout(streamTimeout); return this; } public ServerRSocketFactory resumeCleanupOnKeepAlive() { - resumeCleanupStoreOnKeepAlive = true; - return this; - } - - private class ServerStart implements Start, ServerTransportAcceptor { - private Supplier> transportServer; - - @Override - public ServerTransport.ConnectionAcceptor toConnectionAcceptor() { - return new ServerTransport.ConnectionAcceptor() { - private final ServerSetup serverSetup = serverSetup(); - - @Override - public Mono apply(DuplexConnection connection) { - return acceptor(serverSetup, connection); - } - }; - } - - @Override - @SuppressWarnings("unchecked") - public Start transport(Supplier> transport) { - this.transportServer = (Supplier) transport; - return (Start) this::start; - } - - private Mono acceptor(ServerSetup serverSetup, DuplexConnection connection) { - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, plugins, false); - - return multiplexer - .asSetupConnection() - .receive() - .next() - .flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer)); - } - - private Mono acceptResume( - ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) { - return serverSetup.acceptRSocketResume(resumeFrame, multiplexer); - } - - private Mono accept( - ServerSetup serverSetup, ByteBuf startFrame, ClientServerInputMultiplexer multiplexer) { - switch (FrameHeaderFlyweight.frameType(startFrame)) { - case SETUP: - return acceptSetup(serverSetup, startFrame, multiplexer); - case RESUME: - return acceptResume(serverSetup, startFrame, multiplexer); - default: - return acceptUnknown(startFrame, multiplexer); - } - } - - private Mono acceptSetup( - ServerSetup serverSetup, ByteBuf setupFrame, ClientServerInputMultiplexer multiplexer) { - - if (!SetupFrameFlyweight.isSupportedVersion(setupFrame)) { - return sendError( - multiplexer, - new InvalidSetupException( - "Unsupported version: " - + SetupFrameFlyweight.humanReadableVersion(setupFrame))) - .doFinally( - signalType -> { - setupFrame.release(); - multiplexer.dispose(); - }); - } - - boolean isLeaseEnabled = leaseEnabled; - - if (SetupFrameFlyweight.honorLease(setupFrame) && !isLeaseEnabled) { - return sendError(multiplexer, new InvalidSetupException("lease is not supported")) - .doFinally( - signalType -> { - setupFrame.release(); - multiplexer.dispose(); - }); - } - - return serverSetup.acceptRSocketSetup( - setupFrame, - multiplexer, - (keepAliveHandler, wrappedMultiplexer) -> { - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame); - - Leases leases = leasesSupplier.get(); - RequesterLeaseHandler requesterLeaseHandler = - isLeaseEnabled - ? new RequesterLeaseHandler.Impl(SERVER_TAG, leases.receiver()) - : RequesterLeaseHandler.None; - - RSocket rSocketRequester = - new RSocketRequester( - allocator, - wrappedMultiplexer.asServerConnection(), - payloadDecoder, - errorConsumer, - StreamIdSupplier.serverSupplier(), - setupPayload.keepAliveInterval(), - setupPayload.keepAliveMaxLifetime(), - keepAliveHandler, - requesterLeaseHandler); - - if (multiSubscriberRequester) { - rSocketRequester = new MultiSubscriberRSocket(rSocketRequester); - } - RSocket wrappedRSocketRequester = plugins.applyRequester(rSocketRequester); - - return plugins - .applySocketAcceptorInterceptor(acceptor) - .accept(setupPayload, wrappedRSocketRequester) - .onErrorResume( - err -> sendError(multiplexer, rejectedSetupError(err)).then(Mono.error(err))) - .doOnNext( - rSocketHandler -> { - RSocket wrappedRSocketHandler = plugins.applyResponder(rSocketHandler); - - ResponderLeaseHandler responderLeaseHandler = - isLeaseEnabled - ? new ResponderLeaseHandler.Impl<>( - SERVER_TAG, - allocator, - leases.sender(), - errorConsumer, - leases.stats()) - : ResponderLeaseHandler.None; - - RSocket rSocketResponder = - new RSocketResponder( - allocator, - wrappedMultiplexer.asClientConnection(), - wrappedRSocketHandler, - payloadDecoder, - errorConsumer, - responderLeaseHandler); - }) - .doFinally(signalType -> setupPayload.release()) - .then(); - }); - } - - @Override - public Mono start() { - return Mono.defer( - new Supplier>() { - - ServerSetup serverSetup = serverSetup(); - - @Override - public Mono get() { - return transportServer - .get() - .start(duplexConnection -> acceptor(serverSetup, duplexConnection), mtu) - .doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe()); - } - }); - } - - private ServerSetup serverSetup() { - return resumeSupported - ? new ServerSetup.ResumableServerSetup( - allocator, - new SessionManager(), - resumeSessionDuration, - resumeStreamTimeout, - resumeStoreFactory, - resumeCleanupStoreOnKeepAlive) - : new ServerSetup.DefaultServerSetup(allocator); - } - - private Mono acceptUnknown(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { - return sendError( - multiplexer, - new InvalidSetupException( - "invalid setup frame: " + FrameHeaderFlyweight.frameType(frame))) - .doFinally( - signalType -> { - frame.release(); - multiplexer.dispose(); - }); - } - - private Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - return ConnectionUtils.sendError(allocator, multiplexer, exception); - } - - private Exception rejectedSetupError(Throwable err) { - String msg = err.getMessage(); - return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); - } + resume(); + resume.cleanupStoreOnKeepAlive(); + return this; + } + + @Override + public ServerTransport.ConnectionAcceptor toConnectionAcceptor() { + return server.asConnectionAcceptor(); + } + + @Override + public Start transport(Supplier> transport) { + return () -> server.bind(transport.get()); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java new file mode 100644 index 000000000..feeb5c481 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.SetupFrameFlyweight; + +/** + * Default implementation of {@link ConnectionSetupPayload}. Primarily for internal use within + * RSocket Java but may be created in an application, e.g. for testing purposes. + */ +public class DefaultConnectionSetupPayload extends ConnectionSetupPayload { + + private final ByteBuf setupFrame; + + public DefaultConnectionSetupPayload(ByteBuf setupFrame) { + this.setupFrame = setupFrame; + } + + @Override + public boolean hasMetadata() { + return FrameHeaderFlyweight.hasMetadata(setupFrame); + } + + @Override + public ByteBuf sliceMetadata() { + final ByteBuf metadata = SetupFrameFlyweight.metadata(setupFrame); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + return SetupFrameFlyweight.data(setupFrame); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public String metadataMimeType() { + return SetupFrameFlyweight.metadataMimeType(setupFrame); + } + + @Override + public String dataMimeType() { + return SetupFrameFlyweight.dataMimeType(setupFrame); + } + + @Override + public int keepAliveInterval() { + return SetupFrameFlyweight.keepAliveInterval(setupFrame); + } + + @Override + public int keepAliveMaxLifetime() { + return SetupFrameFlyweight.keepAliveMaxLifetime(setupFrame); + } + + @Override + public int getFlags() { + return FrameHeaderFlyweight.flags(setupFrame); + } + + @Override + public boolean willClientHonorLease() { + return SetupFrameFlyweight.honorLease(setupFrame); + } + + @Override + public boolean isResumeEnabled() { + return SetupFrameFlyweight.resumeEnabled(setupFrame); + } + + @Override + public ByteBuf resumeToken() { + return SetupFrameFlyweight.resumeToken(setupFrame); + } + + @Override + public ConnectionSetupPayload touch() { + setupFrame.touch(); + return this; + } + + @Override + public ConnectionSetupPayload touch(Object hint) { + setupFrame.touch(hint); + return this; + } + + @Override + protected void deallocate() { + setupFrame.release(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java new file mode 100644 index 000000000..3b6b375d1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -0,0 +1,32 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; + +final class PayloadValidationUtils { + static final String INVALID_PAYLOAD_ERROR_MESSAGE = + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."; + + static boolean isValid(int mtu, Payload payload) { + if (mtu > 0) { + return true; + } + + if (payload.hasMetadata()) { + return (((FrameHeaderFlyweight.size() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE + + FrameHeaderFlyweight.size() + + payload.data().readableBytes() + + payload.metadata().readableBytes()) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } else { + return (((FrameHeaderFlyweight.size() + + payload.data().readableBytes() + + FrameLengthFlyweight.FRAME_LENGTH_SIZE) + & ~FrameLengthFlyweight.FRAME_LENGTH_MASK) + == 0); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java new file mode 100644 index 000000000..57aebbdf0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -0,0 +1,355 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.AbstractRSocket; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.lease.LeaseStats; +import io.rsocket.lease.Leases; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.resume.ClientRSocketSession; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketConnector { + private static final String CLIENT_TAG = "client"; + private static final int MIN_MTU_SIZE = 64; + + private static final BiConsumer INVALIDATE_FUNCTION = + (r, i) -> r.onClose().subscribe(null, __ -> i.invalidate(), i::invalidate); + + private Payload setupPayload = EmptyPayload.INSTANCE; + private String metadataMimeType = "application/binary"; + private String dataMimeType = "application/binary"; + + private SocketAcceptor acceptor = (setup, sendingSocket) -> Mono.just(new AbstractRSocket() {}); + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Duration keepAliveInterval = Duration.ofSeconds(20); + private Duration keepAliveMaxLifeTime = Duration.ofSeconds(90); + + private Retry retrySpec; + private Resume resume; + private Supplier> leasesSupplier; + + private int mtu = 0; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private Consumer errorConsumer = ex -> {}; + + private RSocketConnector() {} + + public static RSocketConnector create() { + return new RSocketConnector(); + } + + public static Mono connectWith(ClientTransport transport) { + return RSocketConnector.create().connect(() -> transport); + } + + public RSocketConnector setupPayload(Payload payload) { + this.setupPayload = payload; + return this; + } + + public RSocketConnector dataMimeType(String dataMimeType) { + this.dataMimeType = dataMimeType; + return this; + } + + public RSocketConnector metadataMimeType(String metadataMimeType) { + this.metadataMimeType = metadataMimeType; + return this; + } + + public RSocketConnector interceptors(Consumer consumer) { + consumer.accept(this.interceptors); + return this; + } + + public RSocketConnector acceptor(SocketAcceptor acceptor) { + this.acceptor = acceptor; + return this; + } + + /** + * Set the time {@code interval} between KEEPALIVE frames sent by this client, and the {@code + * maxLifeTime} that this client will allow between KEEPALIVE frames from the server before + * assuming it is dead. + * + *

Note that reasonable values for the time interval may vary significantly. For + * server-to-server connections the spec suggests 500ms, while for for mobile-to-server + * connections it suggests 30+ seconds. In addition {@code maxLifeTime} should allow plenty of + * room for multiple missed ticks from the server. + * + *

By default {@code interval} is set to 20 seconds and {@code maxLifeTime} to 90 seconds. + * + * @param interval the time between KEEPALIVE frames sent, must be greater than 0. + * @param maxLifeTime the max time between KEEPALIVE frames received, must be greater than 0. + */ + public RSocketConnector keepAlive(Duration interval, Duration maxLifeTime) { + if (!interval.negated().isNegative()) { + throw new IllegalArgumentException("`interval` for keepAlive must be > 0"); + } + if (!maxLifeTime.negated().isNegative()) { + throw new IllegalArgumentException("`maxLifeTime` for keepAlive must be > 0"); + } + this.keepAliveInterval = interval; + this.keepAliveMaxLifeTime = maxLifeTime; + return this; + } + + /** + * Enables a reconnectable, shared instance of {@code Mono} so every subscriber will + * observe the same RSocket instance up on connection establishment.
+ * For example: + * + *

{@code
+   * Mono sharedRSocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = sharedRSocketMono.block();
+   *  RSocket r2 = sharedRSocketMono.block();
+   *
+   *  assert r1 == r2;
+   *
+   * }
+ * + * Apart of the shared behavior, if the connection is lost, the same {@code Mono} + * instance will transparently re-establish the connection for subsequent subscribers.
+ * For example: + * + *
{@code
+   * Mono sharedRSocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = sharedRSocketMono.block();
+   *  RSocket r2 = sharedRSocketMono.block();
+   *
+   *  assert r1 == r2;
+   *
+   *  r1.dispose()
+   *
+   *  assert r2.isDisposed()
+   *
+   *  RSocket r3 = sharedRSocketMono.block();
+   *  RSocket r4 = sharedRSocketMono.block();
+   *
+   *  assert r1 != r3;
+   *  assert r4 == r3;
+   *
+   * }
+ * + * Note, having reconnect() enabled does not eliminate the need to accompany each + * individual request with the corresponding retry logic.
+ * For example: + * + *
{@code
+   * Mono sharedRSocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  sharedRSocket.flatMap(rSocket -> rSocket.requestResponse(...))
+   *               .retryWhen(ownRetry)
+   *               .subscribe()
+   *
+   * }
+ * + * @param retrySpec a retry factory applied for {@link Mono#retryWhen(Retry)} + * @return a shared instance of {@code Mono}. + */ + public RSocketConnector reconnect(Retry retrySpec) { + this.retrySpec = Objects.requireNonNull(retrySpec); + return this; + } + + public RSocketConnector resume(Resume resume) { + this.resume = resume; + return this; + } + + public RSocketConnector lease(Supplier> supplier) { + this.leasesSupplier = supplier; + return this; + } + + public RSocketConnector fragment(int mtu) { + if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format("smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } + this.mtu = mtu; + return this; + } + + public RSocketConnector payloadDecoder(PayloadDecoder payloadDecoder) { + Objects.requireNonNull(payloadDecoder); + this.payloadDecoder = payloadDecoder; + return this; + } + + /** + * @deprecated this is deprecated with no replacement and will be removed after {@link + * io.rsocket.RSocketFactory} is removed. + */ + @Deprecated + public RSocketConnector errorConsumer(Consumer errorConsumer) { + Objects.requireNonNull(errorConsumer); + this.errorConsumer = errorConsumer; + return this; + } + + public Mono connect(ClientTransport transport) { + return connect(() -> transport); + } + + public Mono connect(Supplier transportSupplier) { + Mono connectionMono = + Mono.fromSupplier(transportSupplier).flatMap(t -> t.connect(mtu)); + return connectionMono + .flatMap( + connection -> { + ByteBuf resumeToken; + KeepAliveHandler keepAliveHandler; + DuplexConnection wrappedConnection; + + if (resume != null) { + resumeToken = resume.getTokenSupplier().get(); + ClientRSocketSession session = + new ClientRSocketSession( + connection, + resume.getSessionDuration(), + resume.getRetry(), + resume.getStoreFactory(CLIENT_TAG).apply(resumeToken), + resume.getStreamTimeout(), + resume.isCleanupStoreOnKeepAlive()) + .continueWith(connectionMono) + .resumeToken(resumeToken); + keepAliveHandler = + new KeepAliveHandler.ResumableKeepAliveHandler(session.resumableConnection()); + wrappedConnection = session.resumableConnection(); + } else { + resumeToken = Unpooled.EMPTY_BUFFER; + keepAliveHandler = new KeepAliveHandler.DefaultKeepAliveHandler(connection); + wrappedConnection = connection; + } + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(wrappedConnection, interceptors, true); + + boolean leaseEnabled = leasesSupplier != null; + Leases leases = leaseEnabled ? leasesSupplier.get() : null; + RequesterLeaseHandler requesterLeaseHandler = + leaseEnabled + ? new RequesterLeaseHandler.Impl(CLIENT_TAG, leases.receiver()) + : RequesterLeaseHandler.None; + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + errorConsumer, + StreamIdSupplier.clientSupplier(), + mtu, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + keepAliveHandler, + requesterLeaseHandler); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + ByteBuf setupFrame = + SetupFrameFlyweight.encode( + wrappedConnection.alloc(), + leaseEnabled, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + resumeToken, + metadataMimeType, + dataMimeType, + setupPayload); + + ConnectionSetupPayload setup = new DefaultConnectionSetupPayload(setupFrame); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setup, wrappedRSocketRequester) + .flatMap( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + + ResponderLeaseHandler responderLeaseHandler = + leaseEnabled + ? new ResponderLeaseHandler.Impl<>( + CLIENT_TAG, + wrappedConnection.alloc(), + leases.sender(), + errorConsumer, + leases.stats()) + : ResponderLeaseHandler.None; + + RSocket rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + wrappedRSocketHandler, + payloadDecoder, + errorConsumer, + responderLeaseHandler, + mtu); + + return wrappedConnection + .sendOne(setupFrame) + .thenReturn(wrappedRSocketRequester); + }); + }) + .as( + source -> { + if (retrySpec != null) { + return new ReconnectMono<>( + source.retryWhen(retrySpec), Disposable::dispose, INVALIDATE_FUNCTION); + } else { + return source; + } + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java similarity index 55% rename from rsocket-core/src/main/java/io/rsocket/RSocketRequester.java rename to rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 5590a9df0..fefb06003 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -14,15 +14,21 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.Exceptions; import io.rsocket.frame.CancelFrameFlyweight; @@ -37,7 +43,6 @@ import io.rsocket.frame.RequestResponseFrameFlyweight; import io.rsocket.frame.RequestStreamFrameFlyweight; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.RateLimitableRequestPublisher; import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.internal.UnicastMonoEmpty; @@ -48,6 +53,8 @@ import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.util.MonoLifecycleHandler; import java.nio.channels.ClosedChannelException; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Consumer; import java.util.function.LongConsumer; @@ -57,9 +64,11 @@ import org.reactivestreams.Processor; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.SignalType; import reactor.core.publisher.UnicastProcessor; import reactor.util.concurrent.Queues; @@ -72,6 +81,16 @@ class RSocketRequester implements RSocket { AtomicReferenceFieldUpdater.newUpdater( RSocketRequester.class, Throwable.class, "terminationError"); private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; static { CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); @@ -81,50 +100,54 @@ class RSocketRequester implements RSocket { private final PayloadDecoder payloadDecoder; private final Consumer errorConsumer; private final StreamIdSupplier streamIdSupplier; - private final IntObjectMap senders; + private final IntObjectMap senders; private final IntObjectMap> receivers; private final UnboundedProcessor sendProcessor; + private final int mtu; private final RequesterLeaseHandler leaseHandler; private final ByteBufAllocator allocator; private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; private volatile Throwable terminationError; + private final MonoProcessor onClose; RSocketRequester( - ByteBufAllocator allocator, DuplexConnection connection, PayloadDecoder payloadDecoder, Consumer errorConsumer, StreamIdSupplier streamIdSupplier, + int mtu, int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, RequesterLeaseHandler leaseHandler) { - this.allocator = allocator; this.connection = connection; + this.allocator = connection.alloc(); this.payloadDecoder = payloadDecoder; this.errorConsumer = errorConsumer; this.streamIdSupplier = streamIdSupplier; + this.mtu = mtu; this.leaseHandler = leaseHandler; this.senders = new SynchronizedIntObjectHashMap<>(); this.receivers = new SynchronizedIntObjectHashMap<>(); + this.onClose = MonoProcessor.create(); // DO NOT Change the order here. The Send processor must be subscribed to before receiving this.sendProcessor = new UnboundedProcessor<>(); connection .onClose() - .doFinally(signalType -> tryTerminateOnConnectionClose()) - .subscribe(null, errorConsumer); + .or(onClose) + .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); connection.receive().subscribe(this::handleIncomingFrames, errorConsumer); if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { KeepAliveSupport keepAliveSupport = - new ClientKeepAliveSupport(allocator, keepAliveTickPeriod, keepAliveAckTimeout); + new ClientKeepAliveSupport(this.allocator, keepAliveTickPeriod, keepAliveAckTimeout); this.keepAliveFramesAcceptor = keepAliveHandler.start( - keepAliveSupport, sendProcessor::onNext, this::tryTerminateOnKeepAlive); + keepAliveSupport, sendProcessor::onNextPrioritized, this::tryTerminateOnKeepAlive); } else { keepAliveFramesAcceptor = null; } @@ -162,17 +185,17 @@ public double availability() { @Override public void dispose() { - connection.dispose(); + tryTerminate(() -> new CancellationException("Disposed")); } @Override public boolean isDisposed() { - return connection.isDisposed(); + return onClose.isDisposed(); } @Override public Mono onClose() { - return connection.onClose(); + return onClose; } private Mono handleFireAndForget(Payload payload) { @@ -182,18 +205,18 @@ private Mono handleFireAndForget(Payload payload) { return Mono.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + final int streamId = streamIdSupplier.nextStreamId(receivers); return UnicastMonoEmpty.newInstance( () -> { ByteBuf requestFrame = - RequestFireAndForgetFrameFlyweight.encode( - allocator, - streamId, - false, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); - payload.release(); + RequestFireAndForgetFrameFlyweight.encodeReleasingPayload( + allocator, streamId, payload); sendProcessor.onNext(requestFrame); }); @@ -206,6 +229,11 @@ private Mono handleRequestResponse(final Payload payload) { return Mono.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + int streamId = streamIdSupplier.nextStreamId(receivers); final UnboundedProcessor sendProcessor = this.sendProcessor; @@ -215,13 +243,8 @@ private Mono handleRequestResponse(final Payload payload) { @Override public void doOnSubscribe() { final ByteBuf requestFrame = - RequestResponseFrameFlyweight.encode( - allocator, - streamId, - false, - payload.sliceMetadata().retain(), - payload.sliceData().retain()); - payload.release(); + RequestResponseFrameFlyweight.encodeReleasingPayload( + allocator, streamId, payload); sendProcessor.onNext(requestFrame); } @@ -231,17 +254,16 @@ public void doOnTerminal( @Nonnull SignalType signalType, @Nullable Payload element, @Nullable Throwable e) { - if (signalType == SignalType.ON_ERROR) { - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, e)); - } else if (signalType == SignalType.CANCEL) { + if (signalType == SignalType.CANCEL) { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } removeStreamReceiver(streamId); } }); + receivers.put(streamId, receiver); - return receiver; + return receiver.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Flux handleRequestStream(final Payload payload) { @@ -251,10 +273,16 @@ private Flux handleRequestStream(final Payload payload) { return Flux.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + int streamId = streamIdSupplier.nextStreamId(receivers); final UnboundedProcessor sendProcessor = this.sendProcessor; final UnicastProcessor receiver = UnicastProcessor.create(); + final AtomicInteger wip = new AtomicInteger(0); receivers.put(streamId, receiver); @@ -266,35 +294,54 @@ private Flux handleRequestStream(final Payload payload) { @Override public void accept(long n) { - if (firstRequest && !receiver.isDisposed()) { + if (firstRequest) { firstRequest = false; - sendProcessor.onNext( - RequestStreamFrameFlyweight.encode( - allocator, - streamId, - false, - n, - payload.sliceMetadata().retain(), - payload.sliceData().retain())); - payload.release(); - } else if (contains(streamId) && !receiver.isDisposed()) { + if (wip.getAndIncrement() != 0) { + // no need to do anything. + // stream was canceled and fist payload has already been discarded + return; + } + int missed = 1; + boolean firstHasBeenSent = false; + for (; ; ) { + if (!firstHasBeenSent) { + sendProcessor.onNext( + RequestStreamFrameFlyweight.encodeReleasingPayload( + allocator, streamId, n, payload)); + firstHasBeenSent = true; + } else { + // if first frame was sent but we cycling again, it means that wip was + // incremented at doOnCancel + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + return; + } + + missed = wip.addAndGet(-missed); + if (missed == 0) { + return; + } + } + } else if (!receiver.isDisposed()) { sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); } } }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t)); - } - }) .doOnCancel( () -> { - if (contains(streamId) && !receiver.isDisposed()) { + if (wip.getAndIncrement() != 0) { + return; + } + + // check if we need to release payload + // only applicable if the cancel appears earlier than actual request + if (payload.refCnt() > 0) { + payload.release(); + } else { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } }) - .doFinally(s -> removeStreamReceiver(streamId)); + .doFinally(s -> removeStreamReceiver(streamId)) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Flux handleChannel(Flux request) { @@ -303,10 +350,89 @@ private Flux handleChannel(Flux request) { return Flux.error(err); } + return request + .switchOnFirst( + (s, flux) -> { + Payload payload = s.get(); + if (payload != null) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + errorConsumer.accept(t); + return Mono.error(t); + } + return handleChannel(payload, flux); + } else { + return flux; + } + }, + false) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + } + + private Flux handleChannel(Payload initialPayload, Flux inboundFlux) { final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(); final int streamId = streamIdSupplier.nextStreamId(receivers); + final AtomicInteger wip = new AtomicInteger(0); + final UnicastProcessor receiver = UnicastProcessor.create(); + final BaseSubscriber upstreamSubscriber = + new BaseSubscriber() { + + boolean first = true; + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // noops + } + + @Override + protected void hookOnNext(Payload payload) { + if (first) { + // need to skip first since we have already sent it + // no need to release it since it was released earlier on the request establishment + // phase + first = false; + request(1); + return; + } + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + errorConsumer.accept(t); + // no need to send any errors. + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + receiver.onError(t); + return; + } + final ByteBuf frame = + PayloadFrameFlyweight.encodeNextReleasingPayload(allocator, streamId, payload); + + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnComplete() { + ByteBuf frame = PayloadFrameFlyweight.encodeComplete(allocator, streamId); + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnError(Throwable t) { + ByteBuf frame = ErrorFrameFlyweight.encode(allocator, streamId, t); + sendProcessor.onNext(frame); + receiver.onError(t); + } + + @Override + protected void hookFinally(SignalType type) { + senders.remove(streamId, this); + } + }; + return receiver .doOnRequest( new LongConsumer() { @@ -317,85 +443,71 @@ private Flux handleChannel(Flux request) { public void accept(long n) { if (firstRequest) { firstRequest = false; - request - .transform( - f -> { - RateLimitableRequestPublisher wrapped = - RateLimitableRequestPublisher.wrap(f, Queues.SMALL_BUFFER_SIZE); - // Need to set this to one for first the frame - wrapped.request(1); - senders.put(streamId, wrapped); - receivers.put(streamId, receiver); - - return wrapped; - }) - .subscribe( - new BaseSubscriber() { - - boolean firstPayload = true; - - @Override - protected void hookOnNext(Payload payload) { - final ByteBuf frame; - - if (firstPayload) { - firstPayload = false; - frame = - RequestChannelFrameFlyweight.encode( - allocator, - streamId, - false, - false, - n, - payload.sliceMetadata().retain(), - payload.sliceData().retain()); - } else { - frame = - PayloadFrameFlyweight.encode( - allocator, streamId, false, false, true, payload); - } - - sendProcessor.onNext(frame); - payload.release(); - } - - @Override - protected void hookOnComplete() { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - PayloadFrameFlyweight.encodeComplete(allocator, streamId)); - } - if (firstPayload) { - receiver.onComplete(); - } - } - - @Override - protected void hookOnError(Throwable t) { - errorConsumer.accept(t); - receiver.dispose(); - } - }); - } else { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); + if (wip.getAndIncrement() != 0) { + // no need to do anything. + // stream was canceled and fist payload has already been discarded + return; + } + int missed = 1; + boolean firstHasBeenSent = false; + for (; ; ) { + if (!firstHasBeenSent) { + ByteBuf frame; + try { + frame = + RequestChannelFrameFlyweight.encodeReleasingPayload( + allocator, streamId, false, n, initialPayload); + } catch (IllegalReferenceCountException e) { + return; + } + + senders.put(streamId, upstreamSubscriber); + receivers.put(streamId, receiver); + + inboundFlux + .limitRate(Queues.SMALL_BUFFER_SIZE) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(upstreamSubscriber); + + sendProcessor.onNext(frame); + firstHasBeenSent = true; + } else { + // if first frame was sent but we cycling again, it means that wip was + // incremented at doOnCancel + senders.remove(streamId, upstreamSubscriber); + receivers.remove(streamId, receiver); + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + return; + } + + missed = wip.addAndGet(-missed); + if (missed == 0) { + return; + } } + } else { + sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); } } }) .doOnError( t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t)); - } + upstreamSubscriber.cancel(); + receivers.remove(streamId, receiver); }) + .doOnComplete(() -> receivers.remove(streamId, receiver)) .doOnCancel( () -> { - if (contains(streamId) && !receiver.isDisposed()) { + upstreamSubscriber.cancel(); + if (wip.getAndIncrement() != 0) { + return; + } + + // need to send frame only if RequestChannelFrame was sent + if (receivers.remove(streamId, receiver)) { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } - }) - .doFinally(s -> removeStreamReceiverAndSender(streamId)); + }); } private Mono handleMetadataPush(Payload payload) { @@ -405,16 +517,21 @@ private Mono handleMetadataPush(Payload payload) { return Mono.error(err); } + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + return UnicastMonoEmpty.newInstance( () -> { ByteBuf metadataPushFrame = - MetadataPushFrameFlyweight.encode(allocator, payload.sliceMetadata().retain()); - payload.release(); + MetadataPushFrameFlyweight.encodeReleasingPayload(allocator, payload); - sendProcessor.onNext(metadataPushFrame); + sendProcessor.onNextPrioritized(metadataPushFrame); }); } + @Nullable private Throwable checkAvailable() { Throwable err = this.terminationError; if (err != null) { @@ -470,46 +587,58 @@ private void handleStreamZero(FrameType type, ByteBuf frame) { private void handleFrame(int streamId, FrameType type, ByteBuf frame) { Subscriber receiver = receivers.get(streamId); - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - } else { - switch (type) { - case ERROR: - receiver.onError(Exceptions.from(streamId, frame)); - receivers.remove(streamId); - break; - case NEXT_COMPLETE: - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); - break; - case CANCEL: - { - RateLimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - break; + switch (type) { + case NEXT: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onNext(payloadDecoder.apply(frame)); + break; + case NEXT_COMPLETE: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onNext(payloadDecoder.apply(frame)); + receiver.onComplete(); + break; + case COMPLETE: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onComplete(); + receivers.remove(streamId); + break; + case ERROR: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onError(Exceptions.from(streamId, frame)); + receivers.remove(streamId); + break; + case CANCEL: + { + Subscription sender = senders.remove(streamId); + if (sender != null) { + sender.cancel(); } - case NEXT: - receiver.onNext(payloadDecoder.apply(frame)); break; - case REQUEST_N: - { - RateLimitableRequestPublisher sender = senders.get(streamId); - if (sender != null) { - int n = RequestNFrameFlyweight.requestN(frame); - sender.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); - } - break; + } + case REQUEST_N: + { + Subscription sender = senders.get(streamId); + if (sender != null) { + long n = RequestNFrameFlyweight.requestN(frame); + sender.request(n); } - case COMPLETE: - receiver.onComplete(); - receivers.remove(streamId); break; - default: - throw new IllegalStateException( - "Client received supported frame on stream " + streamId + ": " + frame.toString()); - } + } + default: + throw new IllegalStateException( + "Client received supported frame on stream " + streamId + ": " + frame.toString()); } } @@ -544,6 +673,10 @@ private void tryTerminateOnKeepAlive(KeepAlive keepAlive) { String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); } + private void tryTerminateOnConnectionError(Throwable e) { + tryTerminate(() -> e); + } + private void tryTerminateOnConnectionClose() { tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); } @@ -552,16 +685,16 @@ private void tryTerminateOnZeroError(ByteBuf errorFrame) { tryTerminate(() -> Exceptions.from(0, errorFrame)); } - private void tryTerminate(Supplier errorSupplier) { + private void tryTerminate(Supplier errorSupplier) { if (terminationError == null) { - Exception e = errorSupplier.get(); + Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { terminate(e); } } } - private void terminate(Exception e) { + private void terminate(Throwable e) { connection.dispose(); leaseHandler.dispose(); @@ -593,6 +726,7 @@ private void terminate(Exception e) { receivers.clear(); sendProcessor.dispose(); errorConsumer.accept(e); + onClose.onError(e); } private void removeStreamReceiver(int streamId) { @@ -603,18 +737,6 @@ private void removeStreamReceiver(int streamId) { } } - private void removeStreamReceiverAndSender(int streamId) { - /*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one - of its views*/ - if (terminationError == null) { - receivers.remove(streamId); - RateLimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - } - } - private void handleSendProcessorError(Throwable t) { connection.dispose(); } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java similarity index 59% rename from rsocket-core/src/main/java/io/rsocket/RSocketResponder.java rename to rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 490b00967..b5b298e14 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -14,20 +14,33 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.ResponderRSocket; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.*; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.RateLimitableRequestPublisher; import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.lease.ResponderLeaseHandler; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Consumer; +import java.util.function.LongConsumer; +import java.util.function.Supplier; +import javax.annotation.Nullable; import org.reactivestreams.Processor; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -39,6 +52,17 @@ /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ class RSocketResponder implements ResponderRSocket { + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); private final DuplexConnection connection; private final RSocket requestHandler; @@ -46,8 +70,16 @@ class RSocketResponder implements ResponderRSocket { private final PayloadDecoder payloadDecoder; private final Consumer errorConsumer; private final ResponderLeaseHandler leaseHandler; + private final Disposable leaseHandlerDisposable; + private final MonoProcessor onClose; + + private volatile Throwable terminationError; + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketResponder.class, Throwable.class, "terminationError"); + + private final int mtu; - private final IntObjectMap sendingLimitableSubscriptions; private final IntObjectMap sendingSubscriptions; private final IntObjectMap> channelProcessors; @@ -55,14 +87,15 @@ class RSocketResponder implements ResponderRSocket { private final ByteBufAllocator allocator; RSocketResponder( - ByteBufAllocator allocator, DuplexConnection connection, RSocket requestHandler, PayloadDecoder payloadDecoder, Consumer errorConsumer, - ResponderLeaseHandler leaseHandler) { - this.allocator = allocator; + ResponderLeaseHandler leaseHandler, + int mtu) { this.connection = connection; + this.allocator = connection.alloc(); + this.mtu = mtu; this.requestHandler = requestHandler; this.responderRSocket = @@ -71,31 +104,23 @@ class RSocketResponder implements ResponderRSocket { this.payloadDecoder = payloadDecoder; this.errorConsumer = errorConsumer; this.leaseHandler = leaseHandler; - this.sendingLimitableSubscriptions = new SynchronizedIntObjectHashMap<>(); this.sendingSubscriptions = new SynchronizedIntObjectHashMap<>(); this.channelProcessors = new SynchronizedIntObjectHashMap<>(); + this.onClose = MonoProcessor.create(); // DO NOT Change the order here. The Send processor must be subscribed to before receiving // connections this.sendProcessor = new UnboundedProcessor<>(); - connection - .send(sendProcessor) - .doFinally(this::handleSendProcessorCancel) - .subscribe(null, this::handleSendProcessorError); + connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); - Disposable receiveDisposable = connection.receive().subscribe(this::handleFrame, errorConsumer); - Disposable sendLeaseDisposable = leaseHandler.send(sendProcessor::onNext); + connection.receive().subscribe(this::handleFrame, errorConsumer); + leaseHandlerDisposable = leaseHandler.send(sendProcessor::onNextPrioritized); this.connection .onClose() - .doFinally( - s -> { - cleanup(); - receiveDisposable.dispose(); - sendLeaseDisposable.dispose(); - }) - .subscribe(null, errorConsumer); + .or(onClose) + .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); } private void handleSendProcessorError(Throwable t) { @@ -110,17 +135,6 @@ private void handleSendProcessorError(Throwable t) { } }); - sendingLimitableSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - channelProcessors .values() .forEach( @@ -133,43 +147,21 @@ private void handleSendProcessorError(Throwable t) { }); } - private void handleSendProcessorCancel(SignalType t) { - if (SignalType.ON_ERROR == t) { - return; - } - - sendingSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + private void tryTerminateOnConnectionError(Throwable e) { + tryTerminate(() -> e); + } - sendingLimitableSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + private void tryTerminateOnConnectionClose() { + tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); + } - channelProcessors - .values() - .forEach( - subscription -> { - try { - subscription.onComplete(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + cleanup(e); + } + } } @Override @@ -252,23 +244,25 @@ public Mono metadataPush(Payload payload) { @Override public void dispose() { - connection.dispose(); + tryTerminate(() -> new CancellationException("Disposed")); } @Override public boolean isDisposed() { - return connection.isDisposed(); + return onClose.isDisposed(); } @Override public Mono onClose() { - return connection.onClose(); + return onClose; } - private void cleanup() { + private void cleanup(Throwable e) { cleanUpSendingSubscriptions(); - cleanUpChannelProcessors(); + cleanUpChannelProcessors(e); + connection.dispose(); + leaseHandlerDisposable.dispose(); requestHandler.dispose(); sendProcessor.dispose(); } @@ -276,13 +270,19 @@ private void cleanup() { private synchronized void cleanUpSendingSubscriptions() { sendingSubscriptions.values().forEach(Subscription::cancel); sendingSubscriptions.clear(); - - sendingLimitableSubscriptions.values().forEach(Subscription::cancel); - sendingLimitableSubscriptions.clear(); } - private synchronized void cleanUpChannelProcessors() { - channelProcessors.values().forEach(Processor::onComplete); + private synchronized void cleanUpChannelProcessors(Throwable e) { + channelProcessors + .values() + .forEach( + payloadPayloadProcessor -> { + try { + payloadPayloadProcessor.onError(e); + } catch (Throwable t) { + // noops + } + }); channelProcessors.clear(); } @@ -305,12 +305,12 @@ private void handleFrame(ByteBuf frame) { handleRequestN(streamId, frame); break; case REQUEST_STREAM: - int streamInitialRequestN = RequestStreamFrameFlyweight.initialRequestN(frame); + long streamInitialRequestN = RequestStreamFrameFlyweight.initialRequestN(frame); Payload streamPayload = payloadDecoder.apply(frame); - handleStream(streamId, requestStream(streamPayload), streamInitialRequestN); + handleStream(streamId, requestStream(streamPayload), streamInitialRequestN, null); break; case REQUEST_CHANNEL: - int channelInitialRequestN = RequestChannelFrameFlyweight.initialRequestN(frame); + long channelInitialRequestN = RequestChannelFrameFlyweight.initialRequestN(frame); Payload channelPayload = payloadDecoder.apply(frame); handleChannel(streamId, channelPayload, channelInitialRequestN); break; @@ -384,32 +384,28 @@ protected void hookFinally(SignalType type) { } private void handleRequestResponse(int streamId, Mono response) { - response.subscribe( + final BaseSubscriber subscriber = new BaseSubscriber() { private boolean isEmpty = true; - @Override - protected void hookOnSubscribe(Subscription subscription) { - sendingSubscriptions.put(streamId, subscription); - subscription.request(Long.MAX_VALUE); - } - @Override protected void hookOnNext(Payload payload) { if (isEmpty) { isEmpty = false; } - ByteBuf byteBuf; - try { - byteBuf = PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload); - } catch (Throwable t) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); - throw Exceptions.propagate(t); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; } - payload.release(); - + ByteBuf byteBuf = + PayloadFrameFlyweight.encodeNextCompleteReleasingPayload( + allocator, streamId, payload); sendProcessor.onNext(byteBuf); } @@ -427,69 +423,128 @@ protected void hookOnComplete() { @Override protected void hookFinally(SignalType type) { - sendingSubscriptions.remove(streamId); + sendingSubscriptions.remove(streamId, this); } - }); + }; + + sendingSubscriptions.put(streamId, subscriber); + response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); } - private void handleStream(int streamId, Flux response, int initialRequestN) { - response - .transform( - frameFlux -> { - RateLimitableRequestPublisher payloads = - RateLimitableRequestPublisher.wrap(frameFlux, Queues.SMALL_BUFFER_SIZE); - sendingLimitableSubscriptions.put(streamId, payloads); - payloads.request( - initialRequestN >= Integer.MAX_VALUE ? Long.MAX_VALUE : initialRequestN); - return payloads; - }) - .subscribe( - new BaseSubscriber() { - - @Override - protected void hookOnNext(Payload payload) { - ByteBuf byteBuf; - try { - byteBuf = PayloadFrameFlyweight.encodeNext(allocator, streamId, payload); - } catch (Throwable t) { - payload.release(); - throw Exceptions.propagate(t); - } - - payload.release(); - - sendProcessor.onNext(byteBuf); - } + private void handleStream( + int streamId, + Flux response, + long initialRequestN, + @Nullable UnicastProcessor requestChannel) { + final BaseSubscriber subscriber = + new BaseSubscriber() { - @Override - protected void hookOnComplete() { - sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId)); - } + @Override + protected void hookOnSubscribe(Subscription s) { + s.request(initialRequestN); + } - @Override - protected void hookOnError(Throwable throwable) { - handleError(streamId, throwable); + @Override + protected void hookOnNext(Payload payload) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + // specifically for requestChannel case so when Payload is invalid we will not be + // sending CancelFrame and ErrorFrame + // Note: CancelFrame is redundant and due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both Requester and Responder. + // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both the Requester and Responder. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); } + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; + } + + ByteBuf byteBuf = + PayloadFrameFlyweight.encodeNextReleasingPayload(allocator, streamId, payload); + sendProcessor.onNext(byteBuf); + } - @Override - protected void hookFinally(SignalType type) { - sendingLimitableSubscriptions.remove(streamId); + @Override + protected void hookOnComplete() { + sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId)); + } + + @Override + protected void hookOnError(Throwable throwable) { + handleError(streamId, throwable); + } + + @Override + protected void hookOnCancel() { + // specifically for requestChannel case so when requester sends Cancel frame so the + // whole chain MUST be terminated + // Note: CancelFrame is redundant from the responder side due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving a CANCEL, the stream is terminated on the Responder. + // Upon sending a CANCEL, the stream is terminated on the Requester. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); + try { + requestChannel.dispose(); + } catch (Exception e) { + // might be thrown back if stream is cancelled } - }); + } + } + + @Override + protected void hookFinally(SignalType type) { + sendingSubscriptions.remove(streamId); + } + }; + + sendingSubscriptions.put(streamId, subscriber); + response + .limitRate(Queues.SMALL_BUFFER_SIZE) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(subscriber); } - private void handleChannel(int streamId, Payload payload, int initialRequestN) { + private void handleChannel(int streamId, Payload payload, long initialRequestN) { UnicastProcessor frames = UnicastProcessor.create(); channelProcessors.put(streamId, frames); Flux payloads = frames - .doOnCancel( - () -> sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId))) - .doOnError(t -> handleError(streamId, t)) .doOnRequest( - l -> sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, l))) - .doFinally(signalType -> channelProcessors.remove(streamId)); + new LongConsumer() { + boolean first = true; + + @Override + public void accept(long l) { + long n; + if (first) { + first = false; + n = l - 1L; + } else { + n = l; + } + if (n > 0) { + sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); + } + } + }) + .doFinally( + signalType -> { + if (channelProcessors.remove(streamId, frames)) { + if (signalType == SignalType.CANCEL) { + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + } + } + }) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); // not chained, as the payload should be enqueued in the Unicast processor before this method // returns @@ -497,9 +552,9 @@ private void handleChannel(int streamId, Payload payload, int initialRequestN) { frames.onNext(payload); if (responderRSocket != null) { - handleStream(streamId, requestChannel(payload, payloads), initialRequestN); + handleStream(streamId, requestChannel(payload, payloads), initialRequestN, frames); } else { - handleStream(streamId, requestChannel(payloads), initialRequestN); + handleStream(streamId, requestChannel(payloads), initialRequestN, frames); } } @@ -520,10 +575,7 @@ protected void hookOnError(Throwable throwable) { private void handleCancelFrame(int streamId) { Subscription subscription = sendingSubscriptions.remove(streamId); - - if (subscription == null) { - subscription = sendingLimitableSubscriptions.remove(streamId); - } + channelProcessors.remove(streamId); if (subscription != null) { subscription.cancel(); @@ -538,13 +590,9 @@ private void handleError(int streamId, Throwable t) { private void handleRequestN(int streamId, ByteBuf frame) { Subscription subscription = sendingSubscriptions.get(streamId); - if (subscription == null) { - subscription = sendingLimitableSubscriptions.get(streamId); - } - if (subscription != null) { - int n = RequestNFrameFlyweight.requestN(frame); - subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); + long n = RequestNFrameFlyweight.requestN(frame); + subscription.request(n); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java new file mode 100644 index 000000000..c82a2f40a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -0,0 +1,284 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.AbstractRSocket; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.exceptions.InvalidSetupException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.lease.Leases; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.resume.SessionManager; +import io.rsocket.transport.ServerTransport; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.publisher.Mono; + +public final class RSocketServer { + private static final String SERVER_TAG = "server"; + private static final int MIN_MTU_SIZE = 64; + + private SocketAcceptor acceptor = (setup, sendingSocket) -> Mono.just(new AbstractRSocket() {}); + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + private int mtu = 0; + + private Resume resume; + private Supplier> leasesSupplier = null; + + private Consumer errorConsumer = ex -> {}; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private RSocketServer() {} + + public static RSocketServer create() { + return new RSocketServer(); + } + + public static RSocketServer create(SocketAcceptor acceptor) { + return RSocketServer.create().acceptor(acceptor); + } + + public RSocketServer acceptor(SocketAcceptor acceptor) { + Objects.requireNonNull(acceptor); + this.acceptor = acceptor; + return this; + } + + public RSocketServer interceptors(Consumer consumer) { + consumer.accept(this.interceptors); + return this; + } + + public RSocketServer fragment(int mtu) { + if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format("smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } + this.mtu = mtu; + return this; + } + + public RSocketServer resume(Resume resume) { + this.resume = resume; + return this; + } + + public RSocketServer lease(Supplier> supplier) { + this.leasesSupplier = supplier; + return this; + } + + public RSocketServer payloadDecoder(PayloadDecoder payloadDecoder) { + Objects.requireNonNull(payloadDecoder); + this.payloadDecoder = payloadDecoder; + return this; + } + + /** + * @deprecated this is deprecated with no replacement and will be removed after {@link + * io.rsocket.RSocketFactory} is removed. + */ + @Deprecated + public RSocketServer errorConsumer(Consumer errorConsumer) { + this.errorConsumer = errorConsumer; + return this; + } + + public ServerTransport.ConnectionAcceptor asConnectionAcceptor() { + return new ServerTransport.ConnectionAcceptor() { + private final ServerSetup serverSetup = serverSetup(); + + @Override + public Mono apply(DuplexConnection connection) { + return acceptor(serverSetup, connection); + } + }; + } + + public Mono bind(ServerTransport transport) { + return Mono.defer( + new Supplier>() { + ServerSetup serverSetup = serverSetup(); + + @Override + public Mono get() { + return transport + .start(duplexConnection -> acceptor(serverSetup, duplexConnection), mtu) + .doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe()); + } + }); + } + + private Mono acceptor(ServerSetup serverSetup, DuplexConnection connection) { + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(connection, interceptors, false); + + return multiplexer + .asSetupConnection() + .receive() + .next() + .flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer)); + } + + private Mono acceptResume( + ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) { + return serverSetup.acceptRSocketResume(resumeFrame, multiplexer); + } + + private Mono accept( + ServerSetup serverSetup, ByteBuf startFrame, ClientServerInputMultiplexer multiplexer) { + switch (FrameHeaderFlyweight.frameType(startFrame)) { + case SETUP: + return acceptSetup(serverSetup, startFrame, multiplexer); + case RESUME: + return acceptResume(serverSetup, startFrame, multiplexer); + default: + return serverSetup + .sendError( + multiplexer, + new InvalidSetupException( + "invalid setup frame: " + FrameHeaderFlyweight.frameType(startFrame))) + .doFinally( + signalType -> { + startFrame.release(); + multiplexer.dispose(); + }); + } + } + + private Mono acceptSetup( + ServerSetup serverSetup, ByteBuf setupFrame, ClientServerInputMultiplexer multiplexer) { + + if (!SetupFrameFlyweight.isSupportedVersion(setupFrame)) { + return serverSetup + .sendError( + multiplexer, + new InvalidSetupException( + "Unsupported version: " + SetupFrameFlyweight.humanReadableVersion(setupFrame))) + .doFinally( + signalType -> { + setupFrame.release(); + multiplexer.dispose(); + }); + } + + boolean leaseEnabled = leasesSupplier != null; + if (SetupFrameFlyweight.honorLease(setupFrame) && !leaseEnabled) { + return serverSetup + .sendError(multiplexer, new InvalidSetupException("lease is not supported")) + .doFinally( + signalType -> { + setupFrame.release(); + multiplexer.dispose(); + }); + } + + return serverSetup.acceptRSocketSetup( + setupFrame, + multiplexer, + (keepAliveHandler, wrappedMultiplexer) -> { + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(setupFrame); + + Leases leases = leaseEnabled ? leasesSupplier.get() : null; + RequesterLeaseHandler requesterLeaseHandler = + leaseEnabled + ? new RequesterLeaseHandler.Impl(SERVER_TAG, leases.receiver()) + : RequesterLeaseHandler.None; + + RSocket rSocketRequester = + new RSocketRequester( + wrappedMultiplexer.asServerConnection(), + payloadDecoder, + errorConsumer, + StreamIdSupplier.serverSupplier(), + mtu, + setupPayload.keepAliveInterval(), + setupPayload.keepAliveMaxLifetime(), + keepAliveHandler, + requesterLeaseHandler); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setupPayload, wrappedRSocketRequester) + .onErrorResume( + err -> + serverSetup + .sendError(multiplexer, rejectedSetupError(err)) + .then(Mono.error(err))) + .doOnNext( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + DuplexConnection connection = wrappedMultiplexer.asClientConnection(); + + ResponderLeaseHandler responderLeaseHandler = + leaseEnabled + ? new ResponderLeaseHandler.Impl<>( + SERVER_TAG, + connection.alloc(), + leases.sender(), + errorConsumer, + leases.stats()) + : ResponderLeaseHandler.None; + + RSocket rSocketResponder = + new RSocketResponder( + connection, + wrappedRSocketHandler, + payloadDecoder, + errorConsumer, + responderLeaseHandler, + mtu); + }) + .doFinally(signalType -> setupPayload.release()) + .then(); + }); + } + + private ServerSetup serverSetup() { + return resume != null ? createSetup() : new ServerSetup.DefaultServerSetup(); + } + + ServerSetup createSetup() { + return new ServerSetup.ResumableServerSetup( + new SessionManager(), + resume.getSessionDuration(), + resume.getStreamTimeout(), + resume.getStoreFactory(SERVER_TAG), + resume.isCleanupStoreOnKeepAlive()); + } + + private Exception rejectedSetupError(Throwable err) { + String msg = err.getMessage(); + return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java new file mode 100644 index 000000000..81f6625f0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Operators.MonoSubscriber; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class ReconnectMono extends Mono implements Invalidatable, Disposable, Scannable { + + final Mono source; + final BiConsumer onValueReceived; + final Consumer onValueExpired; + final ReconnectMainSubscriber mainSubscriber; + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ReconnectMono.class, "wip"); + + volatile ReconnectInner[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMono.class, ReconnectInner[].class, "subscribers"); + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] EMPTY_UNSUBSCRIBED = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] EMPTY_SUBSCRIBED = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] READY = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] TERMINATED = new ReconnectInner[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + ReconnectMono( + Mono source, + Consumer onValueExpired, + BiConsumer onValueReceived) { + this.source = source; + this.onValueExpired = onValueExpired; + this.onValueReceived = onValueReceived; + this.mainSubscriber = new ReconnectMainSubscriber<>(this); + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return source; + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + + final boolean isDisposed = isDisposed(); + if (key == Attr.TERMINATED) return isDisposed; + if (key == Attr.ERROR) return t; + + return null; + } + + @Override + public void dispose() { + this.terminate(new CancellationException("ReconnectMono has already been disposed")); + } + + @Override + public boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + @Override + @SuppressWarnings("uncheked") + public void subscribe(CoreSubscriber actual) { + final ReconnectInner inner = new ReconnectInner<>(actual, this); + actual.onSubscribe(inner); + + final int state = this.add(inner); + + if (state == READY_STATE) { + inner.complete(this.value); + } else if (state == TERMINATED_STATE) { + inner.onError(this.t); + } + } + + /** + * Block the calling thread indefinitely, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @return the value of this {@code ReconnectMono} + */ + @Override + @Nullable + public T block() { + return block(null); + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@code ReconnectMono} or {@code null} if the timeout is reached and + * the {@code ReconnectMono} has not completed + */ + @Override + @Nullable + @SuppressWarnings("uncheked") + public T block(@Nullable Duration timeout) { + try { + ReconnectInner[] subscribers = this.subscribers; + if (subscribers == READY) { + return this.value; + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("ReconnectMono terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.source.subscribe(this.mainSubscriber); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + ReconnectInner[] inners = this.subscribers; + + if (inners == READY) { + return this.value; + } + if (inners == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = + Exceptions.addSuppressed(re, new Exception("ReconnectMono terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + void terminate(Throwable t) { + if (isDisposed()) { + return; + } + + // writes happens before volatile write + this.t = t; + + final ReconnectInner[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.mainSubscriber.dispose(); + + this.doFinally(); + + for (CoreSubscriber consumer : subscribers) { + consumer.onError(t); + } + } + + void complete() { + ReconnectInner[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + return; + } + + final T value = this.value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.onValueReceived.accept(value, this); + + for (ReconnectInner consumer : subscribers) { + consumer.complete(value); + } + } + + void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + + if (value != null && isDisposed()) { + this.value = null; + this.onValueExpired.accept(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + // Check RSocket is not good + @Override + public void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final ReconnectInner[] subscribers = this.subscribers; + + if (subscribers == READY && SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED)) { + final T value = this.value; + this.value = null; + + if (value != null) { + this.onValueExpired.accept(value); + } + } + } + + int add(ReconnectInner ps) { + for (; ; ) { + ReconnectInner[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + ReconnectInner[] b = new ReconnectInner[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.source.subscribe(this.mainSubscriber); + } + return ADDED_STATE; + } + } + } + + @SuppressWarnings("unchecked") + void remove(ReconnectInner ps) { + for (; ; ) { + ReconnectInner[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + ReconnectInner[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new ReconnectInner[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } + + static final class ReconnectMainSubscriber implements CoreSubscriber { + + final ReconnectMono parent; + + volatile Subscription s; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMainSubscriber.class, Subscription.class, "s"); + + ReconnectMainSubscriber(ReconnectMono parent) { + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final ReconnectMono p = this.parent; + final T value = p.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + p.doFinally(); + return; + } + + if (value == null) { + p.terminate(new IllegalStateException("Unexpected Completion of the Upstream")); + } else { + p.complete(); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + final ReconnectMono p = this.parent; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + p.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // terminate upstream which means retryBackoff has exhausted + p.terminate(t); + } + + @Override + public void onNext(T value) { + if (this.s == Operators.cancelledSubscription()) { + this.parent.onValueExpired.accept(value); + return; + } + + final ReconnectMono p = this.parent; + + p.value = value; + // volatile write and check on racing + p.doFinally(); + } + + void dispose() { + Operators.terminate(S, this); + } + } + + static final class ReconnectInner extends MonoSubscriber { + final ReconnectMono parent; + + ReconnectInner(CoreSubscriber actual, ReconnectMono parent) { + super(actual); + this.parent = parent; + } + + @Override + public void cancel() { + if (!isCancelled()) { + super.cancel(); + this.parent.remove(this); + } + } + + @Override + public void onComplete() { + if (!isCancelled()) { + this.actual.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + if (isCancelled()) { + Operators.onErrorDropped(t, currentContext()); + } else { + this.actual.onError(t); + } + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.parent; + return super.scanUnsafe(key); + } + } +} + +interface Invalidatable { + + void invalidate(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/Resume.java b/rsocket-core/src/main/java/io/rsocket/core/Resume.java new file mode 100644 index 000000000..aedcc9e5e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/Resume.java @@ -0,0 +1,102 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.ResumeFrameFlyweight; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.ResumableFramesStore; +import java.time.Duration; +import java.util.function.Function; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.util.retry.Retry; + +public class Resume { + private static final Logger logger = LoggerFactory.getLogger(Resume.class); + + private Duration sessionDuration = Duration.ofMinutes(2); + private Duration streamTimeout = Duration.ofSeconds(10); + private boolean cleanupStoreOnKeepAlive; + private Function storeFactory; + + private Supplier tokenSupplier = ResumeFrameFlyweight::generateResumeToken; + private Retry retry = + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1)) + .maxBackoff(Duration.ofSeconds(16)) + .jitter(1.0) + .doBeforeRetry(signal -> logger.debug("Connection error", signal.failure())); + + public Resume() {} + + public Resume sessionDuration(Duration sessionDuration) { + this.sessionDuration = sessionDuration; + return this; + } + + public Resume streamTimeout(Duration streamTimeout) { + this.streamTimeout = streamTimeout; + return this; + } + + public Resume cleanupStoreOnKeepAlive() { + this.cleanupStoreOnKeepAlive = true; + return this; + } + + public Resume storeFactory( + Function storeFactory) { + this.storeFactory = storeFactory; + return this; + } + + public Resume token(Supplier supplier) { + this.tokenSupplier = supplier; + return this; + } + + public Resume retry(Retry retry) { + this.retry = retry; + return this; + } + + Duration getSessionDuration() { + return sessionDuration; + } + + Duration getStreamTimeout() { + return streamTimeout; + } + + boolean isCleanupStoreOnKeepAlive() { + return cleanupStoreOnKeepAlive; + } + + Function getStoreFactory(String tag) { + return storeFactory != null + ? storeFactory + : token -> new InMemoryResumableFramesStore(tag, 100_000); + } + + Supplier getTokenSupplier() { + return tokenSupplier; + } + + Retry getRetry() { + return retry; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java similarity index 80% rename from rsocket-core/src/main/java/io/rsocket/internal/ServerSetup.java rename to rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java index dbd8bc173..3e20d3c60 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ServerSetup.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,41 +14,44 @@ * limitations under the License. */ -package io.rsocket.internal; +package io.rsocket.core; import static io.rsocket.keepalive.KeepAliveHandler.*; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; import io.rsocket.exceptions.RejectedResumeException; import io.rsocket.exceptions.UnsupportedSetupException; +import io.rsocket.frame.ErrorFrameFlyweight; import io.rsocket.frame.ResumeFrameFlyweight; import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.resume.*; -import io.rsocket.util.ConnectionUtils; import java.time.Duration; import java.util.function.BiFunction; import java.util.function.Function; import reactor.core.publisher.Mono; -public interface ServerSetup { +abstract class ServerSetup { - Mono acceptRSocketSetup( + abstract Mono acceptRSocketSetup( ByteBuf frame, ClientServerInputMultiplexer multiplexer, BiFunction> then); - Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer); + abstract Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer); - default void dispose() {} + void dispose() {} - class DefaultServerSetup implements ServerSetup { - private final ByteBufAllocator allocator; + Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { + DuplexConnection duplexConnection = multiplexer.asSetupConnection(); + return duplexConnection + .sendOne(ErrorFrameFlyweight.encode(duplexConnection.alloc(), 0, exception)) + .onErrorResume(err -> Mono.empty()); + } - public DefaultServerSetup(ByteBufAllocator allocator) { - this.allocator = allocator; - } + static class DefaultServerSetup extends ServerSetup { @Override public Mono acceptRSocketSetup( @@ -78,28 +81,21 @@ public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexe multiplexer.dispose(); }); } - - private Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - return ConnectionUtils.sendError(allocator, multiplexer, exception); - } } - class ResumableServerSetup implements ServerSetup { - private final ByteBufAllocator allocator; + static class ResumableServerSetup extends ServerSetup { private final SessionManager sessionManager; private final Duration resumeSessionDuration; private final Duration resumeStreamTimeout; private final Function resumeStoreFactory; private final boolean cleanupStoreOnKeepAlive; - public ResumableServerSetup( - ByteBufAllocator allocator, + ResumableServerSetup( SessionManager sessionManager, Duration resumeSessionDuration, Duration resumeStreamTimeout, Function resumeStoreFactory, boolean cleanupStoreOnKeepAlive) { - this.allocator = allocator; this.sessionManager = sessionManager; this.resumeSessionDuration = resumeSessionDuration; this.resumeStreamTimeout = resumeStreamTimeout; @@ -121,7 +117,6 @@ public Mono acceptRSocketSetup( .save( new ServerRSocketSession( multiplexer.asClientServerConnection(), - allocator, resumeSessionDuration, resumeStreamTimeout, resumeStoreFactory, @@ -155,10 +150,6 @@ public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexe } } - private Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - return ConnectionUtils.sendError(allocator, multiplexer, exception); - } - @Override public void dispose() { sessionManager.dispose(); diff --git a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java similarity index 98% rename from rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java rename to rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java index af8c6b3d0..70734b8c0 100644 --- a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java +++ b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import io.netty.util.collection.IntObjectMap; import java.util.concurrent.atomic.AtomicLongFieldUpdater; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/SupportsIterator.java b/rsocket-core/src/main/java/io/rsocket/core/package-info.java similarity index 61% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/SupportsIterator.java rename to rsocket-core/src/main/java/io/rsocket/core/package-info.java index 50d2a326f..a70bb3b16 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/SupportsIterator.java +++ b/rsocket-core/src/main/java/io/rsocket/core/package-info.java @@ -1,9 +1,11 @@ /* + * Copyright 2015-2020 the original author or authors. + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -11,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.queues; -import io.rsocket.internal.jctools.util.InternalAPI; +/** + * Contains core RSocket protocol, client and server implementation classes, including factories to + * create and configure them. + */ +@NonNullApi +package io.rsocket.core; -/** Tagging interface to help testing */ -@InternalAPI -public interface SupportsIterator {} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java index e92534b2a..351e045a3 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * Application layer logic generating a Reactive Streams {@code onError} event. @@ -32,10 +33,9 @@ public final class ApplicationErrorException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ApplicationErrorException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public ApplicationErrorException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ApplicationErrorException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.APPLICATION_ERROR; + public ApplicationErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.APPLICATION_ERROR, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java index 984e8249b..537cf2bf2 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * The Responder canceled the request but may have started processing it (similar to REJECTED but @@ -33,10 +34,9 @@ public final class CanceledException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public CanceledException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public CanceledException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public CanceledException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CANCELED; + public CanceledException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.CANCELED, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java index 3f4f4309d..f1f1a47d8 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * The connection is being terminated. Sender or Receiver of this frame MUST wait for outstanding @@ -33,10 +34,9 @@ public final class ConnectionCloseException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ConnectionCloseException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public ConnectionCloseException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ConnectionCloseException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CONNECTION_CLOSE; + public ConnectionCloseException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.CONNECTION_CLOSE, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java index beaa3d0d0..9581cfc97 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * The connection is being terminated. Sender or Receiver of this frame MAY close the connection @@ -33,10 +34,9 @@ public final class ConnectionErrorException extends RSocketException implements * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ConnectionErrorException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public ConnectionErrorException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ConnectionErrorException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CONNECTION_ERROR; + public ConnectionErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.CONNECTION_ERROR, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java index 6315206b5..5c1154ebd 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java @@ -1,28 +1,36 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; public class CustomRSocketException extends RSocketException { private static final long serialVersionUID = 7873267740343446585L; - private final int errorCode; - /** * Constructs a new exception with the specified message. * * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] * @param message the message - * @throws NullPointerException if {@code message} is {@code null} * @throws IllegalArgumentException if {@code errorCode} is out of allowed range */ public CustomRSocketException(int errorCode, String message) { - super(message); - if (errorCode > ErrorType.MAX_USER_ALLOWED_ERROR_CODE - && errorCode < ErrorType.MIN_USER_ALLOWED_ERROR_CODE) { - throw new IllegalArgumentException( - "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]"); - } - this.errorCode = errorCode; + this(errorCode, message, null); } /** @@ -31,21 +39,14 @@ public CustomRSocketException(int errorCode, String message) { * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} * @throws IllegalArgumentException if {@code errorCode} is out of allowed range */ - public CustomRSocketException(int errorCode, String message, Throwable cause) { - super(message, cause); - if (errorCode > ErrorType.MAX_USER_ALLOWED_ERROR_CODE - && errorCode < ErrorType.MIN_USER_ALLOWED_ERROR_CODE) { + public CustomRSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + if (errorCode > ErrorFrameFlyweight.MAX_USER_ALLOWED_ERROR_CODE + && errorCode < ErrorFrameFlyweight.MIN_USER_ALLOWED_ERROR_CODE) { throw new IllegalArgumentException( - "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]"); + "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]", this); } - this.errorCode = errorCode; - } - - @Override - public int errorCode() { - return errorCode; } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java index 3a10410f0..fe2d304f5 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,21 @@ package io.rsocket.exceptions; -import static io.rsocket.frame.ErrorFrameFlyweight.*; +import static io.rsocket.frame.ErrorFrameFlyweight.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameFlyweight.CANCELED; +import static io.rsocket.frame.ErrorFrameFlyweight.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameFlyweight.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameFlyweight.INVALID; +import static io.rsocket.frame.ErrorFrameFlyweight.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameFlyweight.MAX_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameFlyweight.MIN_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameFlyweight.REJECTED; +import static io.rsocket.frame.ErrorFrameFlyweight.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameFlyweight.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameFlyweight.UNSUPPORTED_SETUP; import io.netty.buffer.ByteBuf; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameFlyweight; import java.util.Objects; @@ -28,10 +40,10 @@ public final class Exceptions { private Exceptions() {} /** - * Create a {@link RSocketException} from a Frame that matches the error code it contains. + * Create a {@link RSocketErrorException} from a Frame that matches the error code it contains. * * @param frame the frame to retrieve the error code and message from - * @return a {@link RSocketException} that matches the error code in the Frame + * @return a {@link RSocketErrorException} that matches the error code in the Frame * @throws NullPointerException if {@code frame} is {@code null} */ public static RuntimeException from(int streamId, ByteBuf frame) { diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java index 4783b1590..a4b28659f 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * The request is invalid. @@ -32,10 +33,9 @@ public final class InvalidException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public InvalidException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public InvalidException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public InvalidException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.INVALID; + public InvalidException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.INVALID, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java index b3705d5b7..1ff53d51d 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * The Setup frame is invalid for the server (it could be that the client is too recent for the old @@ -33,10 +34,9 @@ public final class InvalidSetupException extends SetupException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public InvalidSetupException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public InvalidSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public InvalidSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.INVALID_SETUP; + public InvalidSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.INVALID_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java index 7508a1ee3..93c49d5e2 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,41 +16,48 @@ package io.rsocket.exceptions; -import java.util.Objects; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameFlyweight; import reactor.util.annotation.Nullable; -/** The root of the RSocket exception hierarchy. */ -public abstract class RSocketException extends RuntimeException { +/** + * The root of the RSocket exception hierarchy. + * + * @deprecated please use {@link RSocketErrorException} instead + */ +@Deprecated +public abstract class RSocketException extends RSocketErrorException { private static final long serialVersionUID = 2912815394105575423L; /** - * Constructs a new exception with the specified message. + * Constructs a new exception with the specified message and error code 0x201 (Application error). * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RSocketException(String message) { - super(Objects.requireNonNull(message, "message must not be null")); + this(message, null); } /** - * Constructs a new exception with the specified message and cause. + * Constructs a new exception with the specified message and cause and error code 0x201 + * (Application error). * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} is {@code null} */ public RSocketException(String message, @Nullable Throwable cause) { - super(Objects.requireNonNull(message, "message must not be null"), cause); + super(ErrorFrameFlyweight.APPLICATION_ERROR, message, cause); } /** - * Returns the RSocket error code - * represented by this exception + * Constructs a new exception with the specified error code, message and cause. * - * @return the RSocket error code + * @param errorCode the RSocket protocol error code + * @param message the message + * @param cause the cause of this exception */ - public abstract int errorCode(); + public RSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java index 4ab83182e..3fad3f396 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * Despite being a valid request, the Responder decided to reject it. The Responder guarantees that @@ -34,10 +35,9 @@ public class RejectedException extends RSocketException implements Retryable { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedException(String message) { - super(message); + this(message, null); } /** @@ -45,14 +45,8 @@ public RejectedException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED; + public RejectedException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.REJECTED, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java index 0d4116538..a10eb4197 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * The server rejected the resume, it can specify the reason in the payload. @@ -32,10 +33,9 @@ public final class RejectedResumeException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedResumeException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public RejectedResumeException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedResumeException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED_RESUME; + public RejectedResumeException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.REJECTED_RESUME, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java index 1fa5f604e..6b5dc0f8b 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * The server rejected the setup, it can specify the reason in the payload. @@ -32,10 +33,9 @@ public final class RejectedSetupException extends SetupException implements Retr * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedSetupException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public RejectedSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED_SETUP; + public RejectedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.REJECTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java index 2111a51b1..712508f0b 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ package io.rsocket.exceptions; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; + /** The root of the setup exception hierarchy. */ public abstract class SetupException extends RSocketException { @@ -25,10 +28,11 @@ public abstract class SetupException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} + * @deprecated please use {@link #SetupException(int, String, Throwable)} */ + @Deprecated public SetupException(String message) { - super(message); + this(message, null); } /** @@ -36,9 +40,21 @@ public SetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @deprecated please use {@link #SetupException(int, String, Throwable)} + */ + @Deprecated + public SetupException(String message, @Nullable Throwable cause) { + this(ErrorFrameFlyweight.INVALID_SETUP, message, cause); + } + + /** + * Constructs a new exception with the specified error code, message and cause. + * + * @param errorCode the RSocket protocol code + * @param message the message + * @param cause the cause of this exception */ - public SetupException(String message, Throwable cause) { - super(message, cause); + public SetupException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java index 7d14bc5d2..b112b95be 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameFlyweight; +import javax.annotation.Nullable; /** * Some (or all) of the parameters specified by the client are unsupported by the server. @@ -32,10 +33,9 @@ public final class UnsupportedSetupException extends SetupException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public UnsupportedSetupException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public UnsupportedSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public UnsupportedSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.UNSUPPORTED_SETUP; + public UnsupportedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameFlyweight.UNSUPPORTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java index cbe989d4b..316643e10 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ import static io.rsocket.fragmentation.FrameFragmenter.fragmentFrame; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; import io.rsocket.DuplexConnection; import io.rsocket.frame.FrameHeaderFlyweight; @@ -40,29 +39,25 @@ * href="https://github.com/rsocket/rsocket/blob/master/Protocol.md#fragmentation-and-reassembly">Fragmentation * and Reassembly */ -public final class FragmentationDuplexConnection implements DuplexConnection { +public final class FragmentationDuplexConnection extends ReassemblyDuplexConnection + implements DuplexConnection { private static final int MIN_MTU_SIZE = 64; private static final Logger logger = LoggerFactory.getLogger(FragmentationDuplexConnection.class); private final DuplexConnection delegate; private final int mtu; - private final ByteBufAllocator allocator; private final FrameReassembler frameReassembler; private final boolean encodeLength; private final String type; public FragmentationDuplexConnection( - DuplexConnection delegate, - ByteBufAllocator allocator, - int mtu, - boolean encodeLength, - String type) { + DuplexConnection delegate, int mtu, boolean encodeAndEncodeLength, String type) { + super(delegate, encodeAndEncodeLength); + Objects.requireNonNull(delegate, "delegate must not be null"); - Objects.requireNonNull(allocator, "byteBufAllocator must not be null"); - this.encodeLength = encodeLength; - this.allocator = allocator; + this.encodeLength = encodeAndEncodeLength; this.delegate = delegate; this.mtu = assertMtu(mtu); - this.frameReassembler = new FrameReassembler(allocator); + this.frameReassembler = new FrameReassembler(delegate.alloc()); this.type = type; delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); @@ -110,7 +105,7 @@ public Mono sendOne(ByteBuf frame) { if (shouldFragment(frameType, readableBytes)) { if (logger.isDebugEnabled()) { return delegate.send( - Flux.from(fragmentFrame(allocator, mtu, frame, frameType, encodeLength)) + Flux.from(fragmentFrame(alloc(), mtu, frame, frameType, encodeLength)) .doOnNext( byteBuf -> { ByteBuf f = encodeLength ? FrameLengthFlyweight.frame(byteBuf) : byteBuf; @@ -123,7 +118,7 @@ public Mono sendOne(ByteBuf frame) { })); } else { return delegate.send( - Flux.from(fragmentFrame(allocator, mtu, frame, frameType, encodeLength))); + Flux.from(fragmentFrame(alloc(), mtu, frame, frameType, encodeLength))); } } else { return delegate.sendOne(encode(frame)); @@ -132,38 +127,9 @@ public Mono sendOne(ByteBuf frame) { private ByteBuf encode(ByteBuf frame) { if (encodeLength) { - return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame); - } else { - return frame; - } - } - - private ByteBuf decode(ByteBuf frame) { - if (encodeLength) { - return FrameLengthFlyweight.frame(frame).retain(); + return FrameLengthFlyweight.encode(alloc(), frame.readableBytes(), frame); } else { return frame; } } - - @Override - public Flux receive() { - return delegate - .receive() - .handle( - (byteBuf, sink) -> { - ByteBuf decode = decode(byteBuf); - frameReassembler.reassembleFrame(decode, sink); - }); - } - - @Override - public Mono onClose() { - return delegate.onClose(); - } - - @Override - public void dispose() { - delegate.dispose(); - } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index d634f7374..8593d2be7 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,14 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCountUtil; -import io.rsocket.frame.*; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; import java.util.function.Consumer; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java index 0c446a7c4..1a8d242b2 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -166,8 +166,8 @@ void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { header = frame.copy(frame.readerIndex(), FrameHeaderFlyweight.size()); if (frameType == FrameType.REQUEST_CHANNEL || frameType == FrameType.REQUEST_STREAM) { - int i = RequestChannelFrameFlyweight.initialRequestN(frame); - header.writeInt(i); + long i = RequestChannelFrameFlyweight.initialRequestN(frame); + header.writeInt(i > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) i); } putHeader(streamId, header); } @@ -261,10 +261,16 @@ void reassembleFrame(ByteBuf frame, SynchronousSink sink) { private ByteBuf assembleFrameWithMetadata(ByteBuf frame, int streamId, ByteBuf header) { ByteBuf metadata; CompositeByteBuf cm = removeMetadata(streamId); - if (cm != null) { - metadata = cm.addComponents(true, PayloadFrameFlyweight.metadata(frame).retain()); + + ByteBuf decodedMetadata = PayloadFrameFlyweight.metadata(frame); + if (decodedMetadata != null) { + if (cm != null) { + metadata = cm.addComponents(true, decodedMetadata.retain()); + } else { + metadata = PayloadFrameFlyweight.metadata(frame).retain(); + } } else { - metadata = PayloadFrameFlyweight.metadata(frame).retain(); + metadata = cm != null ? cm : null; } ByteBuf data = assembleData(frame, streamId); diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java new file mode 100644 index 000000000..933755bb2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java @@ -0,0 +1,92 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.fragmentation; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameLengthFlyweight; +import java.util.Objects; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A {@link DuplexConnection} implementation that reassembles {@link ByteBuf}s. + * + * @see Fragmentation + * and Reassembly + */ +public class ReassemblyDuplexConnection implements DuplexConnection { + private final DuplexConnection delegate; + private final FrameReassembler frameReassembler; + private final boolean decodeLength; + + public ReassemblyDuplexConnection(DuplexConnection delegate, boolean decodeLength) { + Objects.requireNonNull(delegate, "delegate must not be null"); + this.decodeLength = decodeLength; + this.delegate = delegate; + this.frameReassembler = new FrameReassembler(delegate.alloc()); + + delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); + } + + @Override + public Mono send(Publisher frames) { + return delegate.send(frames); + } + + @Override + public Mono sendOne(ByteBuf frame) { + return delegate.sendOne(frame); + } + + private ByteBuf decode(ByteBuf frame) { + if (decodeLength) { + return FrameLengthFlyweight.frame(frame).retain(); + } else { + return frame; + } + } + + @Override + public Flux receive() { + return delegate + .receive() + .handle( + (byteBuf, sink) -> { + ByteBuf decode = decode(byteBuf); + frameReassembler.reassembleFrame(decode, sink); + }); + } + + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + + @Override + public Mono onClose() { + return delegate.onClose(); + } + + @Override + public void dispose() { + delegate.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java index 4431f98dd..8cc3fb41a 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java index e4b16fec7..73bfd38f1 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java @@ -3,7 +3,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; -import io.rsocket.buffer.TupleByteBuf; class DataAndMetadataFlyweight { public static final int FRAME_LENGTH_MASK = 0xFFFFFF; @@ -31,38 +30,61 @@ private static int decodeLength(final ByteBuf byteBuf) { return length; } - static ByteBuf encodeOnlyMetadata( - ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata) { - return TupleByteBuf.of(allocator, header, metadata); - } - - static ByteBuf encodeOnlyData(ByteBufAllocator allocator, final ByteBuf header, ByteBuf data) { - return TupleByteBuf.of(allocator, header, data); - } - static ByteBuf encode( - ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata, ByteBuf data) { + ByteBufAllocator allocator, + final ByteBuf header, + ByteBuf metadata, + boolean hasMetadata, + ByteBuf data) { - int length = metadata.readableBytes(); - encodeLength(header, length); - return TupleByteBuf.of(allocator, header, metadata, data); - } + final boolean addData; + if (data != null) { + if (data.isReadable()) { + addData = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + data.release(); + addData = false; + } + } else { + addData = false; + } - static ByteBuf metadataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { + final boolean addMetadata; if (hasMetadata) { - int length = decodeLength(byteBuf); - return byteBuf.readSlice(length); + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } } else { - return Unpooled.EMPTY_BUFFER; + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (hasMetadata) { + int length = metadata.readableBytes(); + encodeLength(header, length); + } + + if (addMetadata && addData) { + return allocator.compositeBuffer(3).addComponents(true, header, metadata, data); + } else if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else if (addData) { + return allocator.compositeBuffer(2).addComponents(true, header, data); + } else { + return header; } } - static ByteBuf metadata(ByteBuf byteBuf, boolean hasMetadata) { - byteBuf.markReaderIndex(); - byteBuf.skipBytes(6); - ByteBuf metadata = metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; + static ByteBuf metadataWithoutMarking(ByteBuf byteBuf) { + int length = decodeLength(byteBuf); + return byteBuf.readSlice(length); } static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { @@ -77,12 +99,4 @@ static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { return Unpooled.EMPTY_BUFFER; } } - - static ByteBuf data(ByteBuf byteBuf, boolean hasMetadata) { - byteBuf.markReaderIndex(); - byteBuf.skipBytes(6); - ByteBuf data = dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return data; - } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java index df9d39ba8..ab26233f1 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java @@ -3,7 +3,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; -import io.rsocket.exceptions.RSocketException; +import io.rsocket.RSocketErrorException; import java.nio.charset.StandardCharsets; public class ErrorFrameFlyweight { @@ -28,7 +28,10 @@ public static ByteBuf encode( ByteBufAllocator allocator, int streamId, Throwable t, ByteBuf data) { ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.ERROR, 0); - int errorCode = errorCodeFromException(t); + int errorCode = + t instanceof RSocketErrorException + ? ((RSocketErrorException) t).errorCode() + : APPLICATION_ERROR; header.writeInt(errorCode); @@ -41,14 +44,6 @@ public static ByteBuf encode(ByteBufAllocator allocator, int streamId, Throwable return encode(allocator, streamId, t, data); } - public static int errorCodeFromException(Throwable t) { - if (t instanceof RSocketException) { - return ((RSocketException) t).errorCode(); - } - - return APPLICATION_ERROR; - } - public static int errorCode(ByteBuf byteBuf) { byteBuf.markReaderIndex(); byteBuf.skipBytes(FrameHeaderFlyweight.size()); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java index ffd99930d..b41a5d59e 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java @@ -5,7 +5,9 @@ * * @see Error * Codes + * @deprecated please use constants in {@link ErrorFrameFlyweight}. */ +@Deprecated public final class ErrorType { /** diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java index df8b308e9..8cb01b08f 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java @@ -14,21 +14,18 @@ public static ByteBuf encode( @Nullable ByteBuf metadata, ByteBuf data) { + final boolean hasMetadata = metadata != null; + int flags = FrameHeaderFlyweight.FLAGS_I; - if (metadata != null) { + if (hasMetadata) { flags |= FrameHeaderFlyweight.FLAGS_M; } - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.EXT, flags); + final ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.EXT, flags); header.writeInt(extendedType); - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } + + return DataAndMetadataFlyweight.encode(allocator, header, metadata, hasMetadata, data); } public static int extendedType(ByteBuf byteBuf) { @@ -56,10 +53,13 @@ public static ByteBuf metadata(ByteBuf byteBuf) { FrameHeaderFlyweight.ensureFrameType(FrameType.EXT, byteBuf); boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } byteBuf.markReaderIndex(); // Extended type byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); + ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf); byteBuf.resetReaderIndex(); return metadata; } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java index 06efeab6c..a91d52782 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java @@ -13,12 +13,7 @@ public static ByteBuf encode(final ByteBufAllocator allocator, ByteBuf header, B public static ByteBuf encode( final ByteBufAllocator allocator, ByteBuf header, @Nullable ByteBuf metadata, ByteBuf data) { - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } + final boolean hasMetadata = metadata != null; + return DataAndMetadataFlyweight.encode(allocator, header, metadata, hasMetadata, data); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java index 6011263fa..622160061 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java @@ -2,7 +2,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.buffer.TupleByteBuf; /** * Some transports like TCP aren't framed, and require a length. This is used by DuplexConnections @@ -35,7 +34,7 @@ private static int decodeLength(final ByteBuf byteBuf) { public static ByteBuf encode(ByteBufAllocator allocator, int length, ByteBuf frame) { ByteBuf buffer = allocator.buffer(); encodeLength(buffer, length); - return TupleByteBuf.of(allocator, buffer, frame); + return allocator.compositeBuffer(2).addComponents(true, buffer, frame); } public static int length(ByteBuf byteBuf) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java index 0d2175fb6..6662d34af 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -22,6 +22,16 @@ public static String toString(ByteBuf frame) { .append(Integer.toBinaryString(FrameHeaderFlyweight.flags(frame))) .append(" Length: " + frame.readableBytes()); + if (frameType.hasInitialRequestN()) { + payload + .append(" InitialRequestN: ") + .append(RequestStreamFrameFlyweight.initialRequestN(frame)); + } + + if (frameType == FrameType.REQUEST_N) { + payload.append(" RequestN: ").append(RequestNFrameFlyweight.requestN(frame)); + } + if (FrameHeaderFlyweight.hasMetadata(frame)) { payload.append("\nMetadata:\n"); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java index e4e6029b3..b591412a6 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java @@ -29,7 +29,7 @@ public static ByteBuf encode( header.writeLong(lp); - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); + return DataAndMetadataFlyweight.encode(allocator, header, null, false, data); } public static boolean respondFlag(ByteBuf byteBuf) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java index 4676f4c9d..32f086a15 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java @@ -13,21 +13,38 @@ public static ByteBuf encode( final int numRequests, @Nullable final ByteBuf metadata) { + final boolean hasMetadata = metadata != null; + int flags = 0; - if (metadata != null) { + if (hasMetadata) { flags |= FrameHeaderFlyweight.FLAGS_M; } - ByteBuf header = + final ByteBuf header = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.LEASE, flags) .writeInt(ttl) .writeInt(numRequests); - if (metadata == null) { - return header; + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); } else { - return DataAndMetadataFlyweight.encodeOnlyMetadata(allocator, header, metadata); + return header; } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java index d37b573ba..a39acef92 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java @@ -2,8 +2,24 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; public class MetadataPushFrameFlyweight { + + public static ByteBuf encodeReleasingPayload(ByteBufAllocator allocator, Payload payload) { + final ByteBuf metadata = payload.metadata().retain(); + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + metadata.release(); + throw e; + } + return encode(allocator, metadata); + } + public static ByteBuf encode(ByteBufAllocator allocator, ByteBuf metadata) { ByteBuf header = FrameHeaderFlyweight.encodeStreamZero( diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java index 4f67d9c72..53ac6150b 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class PayloadFrameFlyweight { @@ -9,6 +10,53 @@ public class PayloadFrameFlyweight { private PayloadFrameFlyweight() {} + public static ByteBuf encodeNextReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + return encodeReleasingPayload(allocator, streamId, false, payload); + } + + public static ByteBuf encodeNextCompleteReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, true, payload); + } + + static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, boolean complete, Payload payload) { + + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } + + return encode(allocator, streamId, false, complete, true, metadata, data); + } + + public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { + return encode(allocator, streamId, false, true, false, null, null); + } + public static ByteBuf encode( ByteBufAllocator allocator, int streamId, @@ -21,53 +69,6 @@ public static ByteBuf encode( allocator, streamId, fragmentFollows, complete, next, 0, metadata, data); } - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - boolean next, - Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - fragmentFollows, - complete, - next, - 0, - payload.hasMetadata() ? payload.metadata().retain() : null, - payload.data().retain()); - } - - public static ByteBuf encodeNextComplete( - ByteBufAllocator allocator, int streamId, Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - false, - true, - true, - 0, - payload.hasMetadata() ? payload.metadata().retain() : null, - payload.data().retain()); - } - - public static ByteBuf encodeNext(ByteBufAllocator allocator, int streamId, Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - false, - false, - true, - 0, - payload.hasMetadata() ? payload.metadata().retain() : null, - payload.data().retain()); - } - - public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { - return FLYWEIGHT.encode(allocator, streamId, false, true, false, 0, null, null); - } - public static ByteBuf data(ByteBuf byteBuf) { return FLYWEIGHT.data(byteBuf); } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java index 06ddcda03..c0db21170 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestChannelFrameFlyweight { @@ -10,25 +11,40 @@ public class RequestChannelFrameFlyweight { private RequestChannelFrameFlyweight() {} - public static ByteBuf encode( + public static ByteBuf encodeReleasingPayload( ByteBufAllocator allocator, int streamId, - boolean fragmentFollows, boolean complete, - long requestN, + long initialRequestN, Payload payload) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } - return FLYWEIGHT.encode( - allocator, - streamId, - fragmentFollows, - complete, - false, - reqN, - payload.metadata(), - payload.data()); + return encode(allocator, streamId, false, complete, initialRequestN, metadata, data); } public static ByteBuf encode( @@ -36,11 +52,15 @@ public static ByteBuf encode( int streamId, boolean fragmentFollows, boolean complete, - long requestN, + long initialRequestN, ByteBuf metadata, ByteBuf data) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; return FLYWEIGHT.encode( allocator, streamId, fragmentFollows, complete, false, reqN, metadata, data); @@ -54,7 +74,8 @@ public static ByteBuf metadata(ByteBuf byteBuf) { return FLYWEIGHT.metadataWithRequestN(byteBuf); } - public static int initialRequestN(ByteBuf byteBuf) { - return FLYWEIGHT.initialRequestN(byteBuf); + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = FLYWEIGHT.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java index 5f2d606e4..e091edcc3 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestFireAndForgetFrameFlyweight { @@ -10,11 +11,36 @@ public class RequestFireAndForgetFrameFlyweight { private RequestFireAndForgetFrameFlyweight() {} - public static ByteBuf encode( - ByteBufAllocator allocator, int streamId, boolean fragmentFollows, Payload payload) { - - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, payload.metadata(), payload.data()); + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } + + return FLYWEIGHT.encode(allocator, streamId, false, metadata, data); } public static ByteBuf encode( diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java index 98d862f36..15fac9f55 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java @@ -29,9 +29,12 @@ ByteBuf encode( int requestN, @Nullable ByteBuf metadata, ByteBuf data) { + + final boolean hasMetadata = metadata != null; + int flags = 0; - if (metadata != null) { + if (hasMetadata) { flags |= FrameHeaderFlyweight.FLAGS_M; } @@ -47,19 +50,13 @@ ByteBuf encode( flags |= FrameHeaderFlyweight.FLAGS_N; } - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, frameType, flags); + final ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, frameType, flags); if (requestN > 0) { header.writeInt(requestN); } - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } + return DataAndMetadataFlyweight.encode(allocator, header, metadata, hasMetadata, data); } ByteBuf data(ByteBuf byteBuf) { @@ -73,9 +70,12 @@ ByteBuf data(ByteBuf byteBuf) { ByteBuf metadata(ByteBuf byteBuf) { boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } byteBuf.markReaderIndex(); byteBuf.skipBytes(FrameHeaderFlyweight.size()); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); + ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf); byteBuf.resetReaderIndex(); return metadata; } @@ -91,9 +91,12 @@ ByteBuf dataWithRequestN(ByteBuf byteBuf) { ByteBuf metadataWithRequestN(ByteBuf byteBuf) { boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } byteBuf.markReaderIndex(); byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); + ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf); byteBuf.resetReaderIndex(); return metadata; } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java index 5a4c4c273..fe2c752cf 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java @@ -8,26 +8,23 @@ private RequestNFrameFlyweight() {} public static ByteBuf encode( final ByteBufAllocator allocator, final int streamId, long requestN) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - return encode(allocator, streamId, reqN); - } - - public static ByteBuf encode(final ByteBufAllocator allocator, final int streamId, int requestN) { - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.REQUEST_N, 0); if (requestN < 1) { throw new IllegalArgumentException("request n is less than 1"); } - return header.writeInt(requestN); + int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + + ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.REQUEST_N, 0); + return header.writeInt(reqN); } - public static int requestN(ByteBuf byteBuf) { + public static long requestN(ByteBuf byteBuf) { FrameHeaderFlyweight.ensureFrameType(FrameType.REQUEST_N, byteBuf); byteBuf.markReaderIndex(); byteBuf.skipBytes(FrameHeaderFlyweight.size()); int i = byteBuf.readInt(); byteBuf.resetReaderIndex(); - return i; + return i == Integer.MAX_VALUE ? Long.MAX_VALUE : i; } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java index 2e06c9b82..782c70965 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestResponseFrameFlyweight { @@ -10,9 +11,36 @@ public class RequestResponseFrameFlyweight { private RequestResponseFrameFlyweight() {} - public static ByteBuf encode( - ByteBufAllocator allocator, int streamId, boolean fragmentFollows, Payload payload) { - return encode(allocator, streamId, fragmentFollows, payload.metadata(), payload.data()); + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } + + return encode(allocator, streamId, false, metadata, data); } public static ByteBuf encode( diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java index 171c41990..2fb209ffb 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestStreamFrameFlyweight { @@ -10,46 +11,54 @@ public class RequestStreamFrameFlyweight { private RequestStreamFrameFlyweight() {} - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - long requestN, - Payload payload) { - return encode( - allocator, streamId, fragmentFollows, requestN, payload.metadata(), payload.data()); - } + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, long initialRequestN, Payload payload) { - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - int requestN, - Payload payload) { - return encode( - allocator, streamId, fragmentFollows, requestN, payload.metadata(), payload.data()); - } + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - long requestN, - ByteBuf metadata, - ByteBuf data) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - return encode(allocator, streamId, fragmentFollows, reqN, metadata, data); + return encode(allocator, streamId, false, initialRequestN, metadata, data); } public static ByteBuf encode( ByteBufAllocator allocator, int streamId, boolean fragmentFollows, - int requestN, + long initialRequestN, ByteBuf metadata, ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, false, false, requestN, metadata, data); + allocator, streamId, fragmentFollows, false, false, reqN, metadata, data); } public static ByteBuf data(ByteBuf byteBuf) { @@ -60,7 +69,8 @@ public static ByteBuf metadata(ByteBuf byteBuf) { return FLYWEIGHT.metadataWithRequestN(byteBuf); } - public static int initialRequestN(ByteBuf byteBuf) { - return FLYWEIGHT.initialRequestN(byteBuf); + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = FLYWEIGHT.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java index 9f92e715f..bfb73fe22 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java @@ -55,8 +55,9 @@ public static ByteBuf encode( final String dataMimeType, final Payload setupPayload) { - ByteBuf metadata = setupPayload.hasMetadata() ? setupPayload.sliceMetadata() : null; - ByteBuf data = setupPayload.sliceData(); + final ByteBuf data = setupPayload.sliceData(); + final boolean hasMetadata = setupPayload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? setupPayload.sliceMetadata() : null; int flags = 0; @@ -68,11 +69,11 @@ public static ByteBuf encode( flags |= FLAGS_WILL_HONOR_LEASE; } - if (metadata != null) { + if (hasMetadata) { flags |= FrameHeaderFlyweight.FLAGS_M; } - ByteBuf header = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.SETUP, flags); + final ByteBuf header = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.SETUP, flags); header.writeInt(CURRENT_VERSION).writeInt(keepaliveInterval).writeInt(maxLifetime); @@ -91,13 +92,8 @@ public static ByteBuf encode( length = ByteBufUtil.utf8Bytes(dataMimeType); header.writeByte(length); ByteBufUtil.writeUtf8(header, dataMimeType); - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } + + return DataAndMetadataFlyweight.encode(allocator, header, metadata, hasMetadata, data); } public static int version(ByteBuf byteBuf) { @@ -192,9 +188,12 @@ public static String dataMimeType(ByteBuf byteBuf) { public static ByteBuf metadata(ByteBuf byteBuf) { boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } byteBuf.markReaderIndex(); skipToPayload(byteBuf); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); + ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf); byteBuf.resetReaderIndex(); return metadata; } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java index 692dcb363..0a77e3820 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -3,8 +3,15 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import io.rsocket.frame.*; -import io.rsocket.util.ByteBufPayload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameFlyweight; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; +import io.rsocket.util.DefaultPayload; import java.nio.ByteBuffer; /** Default Frame decoder that copies the frames contents for easy of use. */ @@ -45,14 +52,18 @@ public Payload apply(ByteBuf byteBuf) { throw new IllegalArgumentException("unsupported frame type: " + type); } - ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); ByteBuffer data = ByteBuffer.allocateDirect(d.readableBytes()); - data.put(d.nioBuffer()); data.flip(); - metadata.put(m.nioBuffer()); - metadata.flip(); - return ByteBufPayload.create(data, metadata); + if (m != null) { + ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); + metadata.put(m.nioBuffer()); + metadata.flip(); + + return DefaultPayload.create(data, metadata); + } + + return DefaultPayload.create(data); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java index 0b63590e8..c92f82428 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java @@ -3,7 +3,14 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import io.rsocket.frame.*; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameFlyweight; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; import io.rsocket.util.ByteBufPayload; /** @@ -46,6 +53,6 @@ public Payload apply(ByteBuf byteBuf) { throw new IllegalArgumentException("unsupported frame type: " + type); } - return ByteBufPayload.create(d.retain(), m.retain()); + return ByteBufPayload.create(d.retain(), m != null ? m.retain() : null); } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index 1e76b6898..cf3eeb120 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -17,12 +17,13 @@ package io.rsocket.internal; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameUtil; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; -import io.rsocket.plugins.PluginRegistry; +import io.rsocket.plugins.InitializingInterceptorRegistry; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,7 +46,8 @@ */ public class ClientServerInputMultiplexer implements Closeable { private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); - private static final PluginRegistry emptyPluginRegistry = new PluginRegistry(); + private static final InitializingInterceptorRegistry emptyInterceptorRegistry = + new InitializingInterceptorRegistry(); private final DuplexConnection setupConnection; private final DuplexConnection serverConnection; @@ -54,23 +56,23 @@ public class ClientServerInputMultiplexer implements Closeable { private final DuplexConnection clientServerConnection; public ClientServerInputMultiplexer(DuplexConnection source) { - this(source, emptyPluginRegistry, false); + this(source, emptyInterceptorRegistry, false); } public ClientServerInputMultiplexer( - DuplexConnection source, PluginRegistry plugins, boolean isClient) { + DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { this.source = source; final MonoProcessor> setup = MonoProcessor.create(); final MonoProcessor> server = MonoProcessor.create(); final MonoProcessor> client = MonoProcessor.create(); - source = plugins.applyConnection(Type.SOURCE, source); + source = registry.initConnection(Type.SOURCE, source); setupConnection = - plugins.applyConnection(Type.SETUP, new InternalDuplexConnection(source, setup)); + registry.initConnection(Type.SETUP, new InternalDuplexConnection(source, setup)); serverConnection = - plugins.applyConnection(Type.SERVER, new InternalDuplexConnection(source, server)); + registry.initConnection(Type.SERVER, new InternalDuplexConnection(source, server)); clientConnection = - plugins.applyConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); + registry.initConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); clientServerConnection = new InternalDuplexConnection(source, client, server); source @@ -200,6 +202,11 @@ public Flux receive() { })); } + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + @Override public void dispose() { source.dispose(); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientSetup.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientSetup.java deleted file mode 100644 index 38217bdc2..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientSetup.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import static io.rsocket.keepalive.KeepAliveHandler.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.DuplexConnection; -import io.rsocket.keepalive.KeepAliveHandler; -import io.rsocket.resume.ClientRSocketSession; -import io.rsocket.resume.ResumableDuplexConnection; -import io.rsocket.resume.ResumableFramesStore; -import io.rsocket.resume.ResumeStrategy; -import java.time.Duration; -import java.util.function.Supplier; -import reactor.core.publisher.Mono; - -public interface ClientSetup { - - DuplexConnection connection(); - - KeepAliveHandler keepAliveHandler(); - - ByteBuf resumeToken(); - - class DefaultClientSetup implements ClientSetup { - private final DuplexConnection connection; - - public DefaultClientSetup(DuplexConnection connection) { - this.connection = connection; - } - - @Override - public DuplexConnection connection() { - return connection; - } - - @Override - public KeepAliveHandler keepAliveHandler() { - return new DefaultKeepAliveHandler(connection); - } - - @Override - public ByteBuf resumeToken() { - return Unpooled.EMPTY_BUFFER; - } - } - - class ResumableClientSetup implements ClientSetup { - private final ByteBuf resumeToken; - private final ResumableDuplexConnection duplexConnection; - private final ResumableKeepAliveHandler keepAliveHandler; - - public ResumableClientSetup( - ByteBufAllocator allocator, - DuplexConnection connection, - Mono newConnectionFactory, - ByteBuf resumeToken, - ResumableFramesStore resumableFramesStore, - Duration resumeSessionDuration, - Duration resumeStreamTimeout, - Supplier resumeStrategySupplier, - boolean cleanupStoreOnKeepAlive) { - - ClientRSocketSession rSocketSession = - new ClientRSocketSession( - connection, - allocator, - resumeSessionDuration, - resumeStrategySupplier, - resumableFramesStore, - resumeStreamTimeout, - cleanupStoreOnKeepAlive) - .continueWith(newConnectionFactory) - .resumeToken(resumeToken); - this.duplexConnection = rSocketSession.resumableConnection(); - this.keepAliveHandler = new ResumableKeepAliveHandler(duplexConnection); - this.resumeToken = resumeToken; - } - - @Override - public DuplexConnection connection() { - return duplexConnection; - } - - @Override - public KeepAliveHandler keepAliveHandler() { - return keepAliveHandler; - } - - @Override - public ByteBuf resumeToken() { - return resumeToken; - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java b/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java deleted file mode 100755 index cdb0d0c0c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java +++ /dev/null @@ -1,242 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import javax.annotation.Nullable; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; - -/** */ -public class RateLimitableRequestPublisher extends Flux implements Subscription { - - private static final int NOT_CANCELED_STATE = 0; - private static final int CANCELED_STATE = 1; - - private final Publisher source; - - private volatile int canceled; - private static final AtomicIntegerFieldUpdater CANCELED = - AtomicIntegerFieldUpdater.newUpdater(RateLimitableRequestPublisher.class, "canceled"); - - private final long prefetch; - private final long limit; - - private long externalRequested; // need sync - private int pendingToFulfil; // need sync since should be checked/zerroed in onNext - // and increased in request - private int deliveredElements; // no need to sync since increased zerroed only in - // the request method - - private boolean subscribed; - - private @Nullable Subscription internalSubscription; - - private RateLimitableRequestPublisher(Publisher source, long prefetch) { - this.source = source; - this.prefetch = prefetch; - this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : (prefetch - (prefetch >> 2)); - } - - public static RateLimitableRequestPublisher wrap(Publisher source, long prefetch) { - return new RateLimitableRequestPublisher<>(source, prefetch); - } - - @Override - public void subscribe(CoreSubscriber destination) { - synchronized (this) { - if (subscribed) { - throw new IllegalStateException("only one subscriber at a time"); - } - - subscribed = true; - } - final InnerOperator s = new InnerOperator(destination); - - source.subscribe(s); - destination.onSubscribe(s); - } - - @Override - public void request(long n) { - synchronized (this) { - long requested = externalRequested; - if (requested == Long.MAX_VALUE) { - return; - } - externalRequested = Operators.addCap(n, requested); - } - - requestN(); - } - - private void requestN() { - final long r; - final Subscription s; - - synchronized (this) { - s = internalSubscription; - if (s == null) { - return; - } - - final long er = externalRequested; - final long p = prefetch; - final int pendingFulfil = pendingToFulfil; - - if (er != Long.MAX_VALUE || p != Integer.MAX_VALUE) { - // shortcut - if (pendingFulfil == p) { - return; - } - - r = Math.min(p - pendingFulfil, er); - if (er != Long.MAX_VALUE) { - externalRequested -= r; - } - if (p != Integer.MAX_VALUE) { - pendingToFulfil += r; - } - } else { - r = Long.MAX_VALUE; - } - } - - if (r > 0) { - s.request(r); - } - } - - public void cancel() { - if (!isCanceled() && CANCELED.compareAndSet(this, NOT_CANCELED_STATE, CANCELED_STATE)) { - Subscription s; - - synchronized (this) { - s = internalSubscription; - internalSubscription = null; - subscribed = false; - } - - if (s != null) { - s.cancel(); - } - } - } - - private boolean isCanceled() { - return canceled == CANCELED_STATE; - } - - private class InnerOperator implements CoreSubscriber, Subscription { - final Subscriber destination; - - private InnerOperator(Subscriber destination) { - this.destination = destination; - } - - @Override - public void onSubscribe(Subscription s) { - synchronized (RateLimitableRequestPublisher.this) { - RateLimitableRequestPublisher.this.internalSubscription = s; - - if (isCanceled()) { - s.cancel(); - subscribed = false; - RateLimitableRequestPublisher.this.internalSubscription = null; - } - } - - requestN(); - } - - @Override - public void onNext(T t) { - try { - destination.onNext(t); - - if (prefetch == Integer.MAX_VALUE) { - return; - } - - final long l = limit; - int d = deliveredElements + 1; - - if (d == l) { - d = 0; - final long r; - final Subscription s; - - synchronized (RateLimitableRequestPublisher.this) { - long er = externalRequested; - s = internalSubscription; - - if (s == null) { - return; - } - - if (er >= l) { - er -= l; - // keep pendingToFulfil as is since it is eq to prefetch - r = l; - } else { - pendingToFulfil -= l; - if (er > 0) { - r = er; - er = 0; - pendingToFulfil += r; - } else { - r = 0; - } - } - - externalRequested = er; - } - - if (r > 0) { - s.request(r); - } - } - - deliveredElements = d; - } catch (Throwable e) { - onError(e); - } - } - - @Override - public void onError(Throwable t) { - destination.onError(t); - } - - @Override - public void onComplete() { - destination.onComplete(); - } - - @Override - public void request(long n) {} - - @Override - public void cancel() { - RateLimitableRequestPublisher.this.cancel(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index dfcc13a64..cb8b5d63d 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -43,40 +43,58 @@ public final class UnboundedProcessor extends FluxProcessor implements Fuseable.QueueSubscription, Fuseable { + final Queue queue; + final Queue priorityQueue; + + volatile boolean done; + Throwable error; + // important to not loose the downstream too early and miss discard hook, while + // having relevant hasDownstreams() + boolean hasDownstream; + volatile CoreSubscriber actual; + + volatile boolean cancelled; + + volatile int once; + @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater ONCE = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "once"); + volatile int wip; + @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater WIP = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "wip"); + volatile int discardGuard; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater DISCARD_GUARD = + AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "discardGuard"); + + volatile long requested; + @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); - final Queue queue; - volatile boolean done; - Throwable error; - volatile CoreSubscriber actual; - volatile boolean cancelled; - volatile int once; - volatile int wip; - volatile long requested; - volatile boolean outputFused; + boolean outputFused; public UnboundedProcessor() { this.queue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.priorityQueue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); } @Override public int getBufferSize() { - return Queues.capacity(this.queue); + return Integer.MAX_VALUE; } @Override public Object scanUnsafe(Attr key) { if (Attr.BUFFERED == key) return queue.size(); + if (Attr.PREFETCH == key) return Integer.MAX_VALUE; return super.scanUnsafe(key); } @@ -84,6 +102,7 @@ void drainRegular(Subscriber a) { int missed = 1; final Queue q = queue; + final Queue pq = priorityQueue; for (; ; ) { @@ -93,10 +112,18 @@ void drainRegular(Subscriber a) { while (r != e) { boolean d = done; - T t = q.poll(); - boolean empty = t == null; + T t; + boolean empty; + + if (!pq.isEmpty()) { + t = pq.poll(); + empty = false; + } else { + t = q.poll(); + empty = t == null; + } - if (checkTerminated(d, empty, a, q)) { + if (checkTerminated(d, empty, a)) { return; } @@ -110,7 +137,7 @@ void drainRegular(Subscriber a) { } if (r == e) { - if (checkTerminated(done, q.isEmpty(), a, q)) { + if (checkTerminated(done, q.isEmpty() && pq.isEmpty(), a)) { return; } } @@ -129,13 +156,11 @@ void drainRegular(Subscriber a) { void drainFused(Subscriber a) { int missed = 1; - final Queue q = queue; - for (; ; ) { if (cancelled) { - q.clear(); - actual = null; + this.clear(); + hasDownstream = false; return; } @@ -144,7 +169,7 @@ void drainFused(Subscriber a) { a.onNext(null); if (d) { - actual = null; + hasDownstream = false; Throwable ex = error; if (ex != null) { @@ -164,6 +189,9 @@ void drainFused(Subscriber a) { public void drain() { if (WIP.getAndIncrement(this) != 0) { + if (cancelled) { + this.clear(); + } return; } @@ -188,20 +216,15 @@ public void drain() { } } - boolean checkTerminated(boolean d, boolean empty, Subscriber a, Queue q) { + boolean checkTerminated(boolean d, boolean empty, Subscriber a) { if (cancelled) { - while (!q.isEmpty()) { - T t = q.poll(); - if (t != null) { - release(t); - } - } - actual = null; + this.clear(); + hasDownstream = false; return true; } if (d && empty) { Throwable e = error; - actual = null; + hasDownstream = false; if (e != null) { a.onError(e); } else { @@ -222,10 +245,6 @@ public void onSubscribe(Subscription s) { } } - public long available() { - return requested; - } - @Override public int getPrefetch() { return Integer.MAX_VALUE; @@ -237,6 +256,23 @@ public Context currentContext() { return actual != null ? actual.currentContext() : Context.empty(); } + public void onNextPrioritized(T t) { + if (done || cancelled) { + Operators.onNextDropped(t, currentContext()); + release(t); + return; + } + + if (!priorityQueue.offer(t)) { + Throwable ex = + Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); + onError(Operators.onOperatorError(null, ex, t, currentContext())); + release(t); + return; + } + drain(); + } + @Override public void onNext(T t) { if (done || cancelled) { @@ -287,7 +323,7 @@ public void subscribe(CoreSubscriber actual) { actual.onSubscribe(this); this.actual = actual; if (cancelled) { - this.actual = null; + this.hasDownstream = false; } else { drain(); } @@ -314,38 +350,56 @@ public void cancel() { cancelled = true; if (WIP.getAndIncrement(this) == 0) { - clear(); - actual = null; + this.clear(); + hasDownstream = false; } } - @Override - public T peek() { - return queue.peek(); - } - @Override @Nullable public T poll() { + Queue pq = this.priorityQueue; + if (!pq.isEmpty()) { + return pq.poll(); + } return queue.poll(); } @Override public int size() { - return queue.size(); + return priorityQueue.size() + queue.size(); } @Override public boolean isEmpty() { - return queue.isEmpty(); + return priorityQueue.isEmpty() && queue.isEmpty(); } @Override public void clear() { - while (!queue.isEmpty()) { - T t = queue.poll(); - if (t != null) { - release(t); + if (DISCARD_GUARD.getAndIncrement(this) != 0) { + return; + } + + int missed = 1; + + for (; ; ) { + while (!queue.isEmpty()) { + T t = queue.poll(); + if (t != null) { + release(t); + } + } + while (!priorityQueue.isEmpty()) { + T t = priorityQueue.poll(); + if (t != null) { + release(t); + } + } + + missed = DISCARD_GUARD.addAndGet(this, -missed); + if (missed == 0) { + break; } } } @@ -387,14 +441,18 @@ public long downstreamCount() { @Override public boolean hasDownstreams() { - return actual != null; + return hasDownstream; } void release(T t) { if (t instanceof ReferenceCounted) { ReferenceCounted refCounted = (ReferenceCounted) t; if (refCounted.refCnt() > 0) { - refCounted.release(); + try { + refCounted.release(); + } catch (Throwable ex) { + // no ops + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoEmpty.java b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoEmpty.java index eb8a1aa11..64a7d4422 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoEmpty.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoEmpty.java @@ -9,10 +9,8 @@ import reactor.util.annotation.Nullable; /** - * Represents an empty publisher which only calls onSubscribe and onComplete. - * - *

This Publisher is effectively stateless and only a single instance exists. Use the {@link - * #instance()} method to obtain a properly type-parametrized view of it. + * Represents an empty publisher which only calls onSubscribe and onComplete and allows only a + * single subscriber. * * @see Reactive-Streams-Commons */ diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java index af4c8768b..c5b06e086 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java @@ -99,6 +99,7 @@ public static UnicastMonoProcessor create(MonoLifecycleHandler lifecyc UnicastMonoProcessor.class, Subscription.class, "subscription"); CoreSubscriber actual; + boolean hasDownstream = false; Throwable error; O value; @@ -185,7 +186,7 @@ private void complete(O v) { if (state == HAS_REQUEST_NO_RESULT) { if (STATE.compareAndSet(this, HAS_REQUEST_NO_RESULT, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; - actual = null; + hasDownstream = false; value = null; lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, v, null); a.onNext(v); @@ -222,7 +223,7 @@ private void complete() { if (state == HAS_REQUEST_NO_RESULT || state == NO_REQUEST_NO_RESULT) { if (STATE.compareAndSet(this, state, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; - actual = null; + hasDownstream = false; lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, null, null); a.onComplete(); return; @@ -256,7 +257,7 @@ private void complete(Throwable e) { if (state == HAS_REQUEST_NO_RESULT || state == NO_REQUEST_NO_RESULT) { if (STATE.compareAndSet(this, state, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; - actual = null; + hasDownstream = false; lifecycleHandler.doOnTerminal(SignalType.ON_ERROR, null, e); a.onError(e); return; @@ -278,6 +279,7 @@ public void subscribe(CoreSubscriber actual) { lh.doOnSubscribe(); + this.hasDownstream = true; this.actual = actual; int state = this.state; @@ -303,7 +305,7 @@ public void subscribe(CoreSubscriber actual) { // no value // e.g. [onError / onComplete / dispose] only if (state == NO_REQUEST_HAS_RESULT && this.value == null) { - this.actual = null; + this.hasDownstream = false; Throwable e = this.error; // barrier to flush changes STATE.set(this, HAS_REQUEST_HAS_RESULT); @@ -340,7 +342,7 @@ public final void request(long n) { if (STATE.compareAndSet(this, NO_REQUEST_HAS_RESULT, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; final O v = value; - actual = null; + hasDownstream = false; value = null; lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, v, null); a.onNext(v); @@ -360,7 +362,7 @@ public final void cancel() { if (STATE.getAndSet(this, CANCELLED) <= HAS_REQUEST_NO_RESULT) { Operators.onDiscard(value, currentContext()); value = null; - actual = null; + hasDownstream = false; lifecycleHandler.doOnTerminal(SignalType.CANCEL, null, null); final Subscription s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription()); if (s != null && s != Operators.cancelledSubscription()) { @@ -502,6 +504,6 @@ public Object scanUnsafe(Attr key) { * @return true if any {@link Subscriber} is actively subscribed */ public final boolean hasDownstream() { - return state > NO_SUBSCRIBER_HAS_RESULT && actual != null; + return state > NO_SUBSCRIBER_HAS_RESULT && hasDownstream; } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/MissingLeaseException.java b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java similarity index 60% rename from rsocket-core/src/main/java/io/rsocket/exceptions/MissingLeaseException.java rename to rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java index 4bd6ffb99..734d16d07 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/MissingLeaseException.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java @@ -1,6 +1,21 @@ -package io.rsocket.exceptions; +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.lease; -import io.rsocket.lease.Lease; +import io.rsocket.exceptions.RejectedException; import java.util.Objects; import javax.annotation.Nonnull; import javax.annotation.Nullable; diff --git a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java index ca2111e87..dd4247090 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java @@ -18,7 +18,6 @@ import io.netty.buffer.ByteBuf; import io.rsocket.Availability; -import io.rsocket.exceptions.MissingLeaseException; import io.rsocket.frame.LeaseFrameFlyweight; import java.util.function.Consumer; import reactor.core.Disposable; diff --git a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java index c517a55c4..5ca745ee7 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java @@ -19,7 +19,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.Availability; -import io.rsocket.exceptions.MissingLeaseException; import io.rsocket.frame.LeaseFrameFlyweight; import java.util.Optional; import java.util.function.Consumer; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java index 2743f604d..e78e87629 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java @@ -72,7 +72,8 @@ public enum WellKnownMimeType { APPLICATION_CLOUDEVENTS_JSON("application/cloudevents+json", (byte) 0x28), // ... reserved for future use ... - + MESSAGE_RSOCKET_MIMETYPE("message/x.rsocket.mime-type.v0", (byte) 0x7A), + MESSAGE_RSOCKET_ACCEPT_MIMETYPES("message/x.rsocket.accept-mime-types.v0", (byte) 0x7B), MESSAGE_RSOCKET_AUTHENTICATION("message/x.rsocket.authentication.v0", (byte) 0x7C), MESSAGE_RSOCKET_TRACING_ZIPKIN("message/x.rsocket.tracing-zipkin.v0", (byte) 0x7D), MESSAGE_RSOCKET_ROUTING("message/x.rsocket.routing.v0", (byte) 0x7E), diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java index f0f5cf54e..27bf4d1da 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java @@ -5,7 +5,6 @@ import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; -import io.rsocket.buffer.TupleByteBuf; import io.rsocket.util.CharByteBufUtil; public class AuthMetadataFlyweight { @@ -49,7 +48,7 @@ public static ByteBuf encodeMetadata( ByteBufUtil.reserveAndWriteUtf8(headerBuffer, customAuthType, actualASCIILength); - return TupleByteBuf.of(allocator, headerBuffer, metadata); + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); } /** @@ -76,7 +75,7 @@ public static ByteBuf encodeMetadata( .buffer(capacity, capacity) .writeByte(authType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); - return TupleByteBuf.of(allocator, headerBuffer, metadata); + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); } /** diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java new file mode 100644 index 000000000..cf911b954 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; + +public class InitializingInterceptorRegistry extends InterceptorRegistry { + + public DuplexConnection initConnection( + DuplexConnectionInterceptor.Type type, DuplexConnection connection) { + for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { + connection = interceptor.apply(type, connection); + } + return connection; + } + + public RSocket initRequester(RSocket rsocket) { + for (RSocketInterceptor interceptor : getRequesterInteceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public RSocket initResponder(RSocket rsocket) { + for (RSocketInterceptor interceptor : getResponderInterceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public SocketAcceptor initSocketAcceptor(SocketAcceptor acceptor) { + for (SocketAcceptorInterceptor interceptor : getSocketAcceptorInterceptors()) { + acceptor = interceptor.apply(acceptor); + } + return acceptor; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java new file mode 100644 index 000000000..f9ee151a8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -0,0 +1,83 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +public class InterceptorRegistry { + private List connectionInterceptors = new ArrayList<>(); + private List requesterInteceptors = new ArrayList<>(); + private List responderInterceptors = new ArrayList<>(); + private List socketAcceptorInterceptors = new ArrayList<>(); + + public InterceptorRegistry forConnection(DuplexConnectionInterceptor interceptor) { + connectionInterceptors.add(interceptor); + return this; + } + + public InterceptorRegistry forConnection(Consumer> consumer) { + consumer.accept(connectionInterceptors); + return this; + } + + public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { + requesterInteceptors.add(interceptor); + return this; + } + + public InterceptorRegistry forRequester(Consumer> consumer) { + consumer.accept(requesterInteceptors); + return this; + } + + public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { + responderInterceptors.add(interceptor); + return this; + } + + public InterceptorRegistry forResponder(Consumer> consumer) { + consumer.accept(responderInterceptors); + return this; + } + + public InterceptorRegistry forSocketAcceptor(SocketAcceptorInterceptor interceptor) { + socketAcceptorInterceptors.add(interceptor); + return this; + } + + public InterceptorRegistry forSocketAcceptor(Consumer> consumer) { + consumer.accept(socketAcceptorInterceptors); + return this; + } + + List getConnectionInterceptors() { + return connectionInterceptors; + } + + List getRequesterInteceptors() { + return requesterInteceptors; + } + + List getResponderInterceptors() { + return responderInterceptors; + } + + List getSocketAcceptorInterceptors() { + return socketAcceptorInterceptors; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java deleted file mode 100644 index e3a19367c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -import io.rsocket.DuplexConnection; -import io.rsocket.RSocket; -import io.rsocket.SocketAcceptor; -import java.util.ArrayList; -import java.util.List; - -public class PluginRegistry { - private List connections = new ArrayList<>(); - private List requesters = new ArrayList<>(); - private List responders = new ArrayList<>(); - private List socketAcceptorInterceptors = new ArrayList<>(); - - public PluginRegistry() {} - - public PluginRegistry(PluginRegistry defaults) { - this.connections.addAll(defaults.connections); - this.requesters.addAll(defaults.requesters); - this.responders.addAll(defaults.responders); - } - - public void addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - connections.add(interceptor); - } - - /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ - @Deprecated - public void addClientPlugin(RSocketInterceptor interceptor) { - addRequesterPlugin(interceptor); - } - - public void addRequesterPlugin(RSocketInterceptor interceptor) { - requesters.add(interceptor); - } - - /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ - @Deprecated - public void addServerPlugin(RSocketInterceptor interceptor) { - addResponderPlugin(interceptor); - } - - public void addResponderPlugin(RSocketInterceptor interceptor) { - responders.add(interceptor); - } - - public void addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - socketAcceptorInterceptors.add(interceptor); - } - - /** Deprecated. Use {@link #applyRequester(RSocket)} instead */ - @Deprecated - public RSocket applyClient(RSocket rSocket) { - return applyRequester(rSocket); - } - - public RSocket applyRequester(RSocket rSocket) { - for (RSocketInterceptor i : requesters) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - /** Deprecated. Use {@link #applyResponder(RSocket)} instead */ - @Deprecated - public RSocket applyServer(RSocket rSocket) { - return applyResponder(rSocket); - } - - public RSocket applyResponder(RSocket rSocket) { - for (RSocketInterceptor i : responders) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - public SocketAcceptor applySocketAcceptorInterceptor(SocketAcceptor acceptor) { - for (SocketAcceptorInterceptor i : socketAcceptorInterceptors) { - acceptor = i.apply(acceptor); - } - - return acceptor; - } - - public DuplexConnection applyConnection( - DuplexConnectionInterceptor.Type type, DuplexConnection connection) { - for (DuplexConnectionInterceptor i : connections) { - connection = i.apply(type, connection); - } - - return connection; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java b/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java deleted file mode 100644 index 1ac147687..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -/** JVM wide plugins for RSocket */ -public class Plugins { - private static PluginRegistry DEFAULT = new PluginRegistry(); - - private Plugins() {} - - public static void interceptConnection(DuplexConnectionInterceptor interceptor) { - DEFAULT.addConnectionPlugin(interceptor); - } - - public static void interceptClient(RSocketInterceptor interceptor) { - DEFAULT.addClientPlugin(interceptor); - } - - public static void interceptServer(RSocketInterceptor interceptor) { - DEFAULT.addServerPlugin(interceptor); - } - - public static PluginRegistry defaultPlugins() { - return DEFAULT; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java index c9201ca5b..0cb9d92d2 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2015-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java index b347642e3..01b6dfeae 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,10 +26,11 @@ import io.rsocket.internal.ClientServerInputMultiplexer; import java.time.Duration; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; public class ClientRSocketSession implements RSocketSession> { private static final Logger logger = LoggerFactory.getLogger(ClientRSocketSession.class); @@ -41,13 +42,12 @@ public class ClientRSocketSession implements RSocketSession resumeStrategy, + Retry retry, ResumableFramesStore resumableFramesStore, Duration resumeStreamTimeout, boolean cleanupStoreOnKeepAlive) { - this.allocator = allocator; + this.allocator = duplexConnection.alloc(); this.resumableConnection = new ResumableDuplexConnection( "client", @@ -64,24 +64,13 @@ public ClientRSocketSession( .flatMap( err -> { logger.debug("Client session connection error. Starting new connection"); - ResumeStrategy reconnectOnError = resumeStrategy.get(); - ClientResume clientResume = new ClientResume(resumeSessionDuration, resumeToken); AtomicBoolean once = new AtomicBoolean(); return newConnection .delaySubscription( once.compareAndSet(false, true) - ? reconnectOnError.apply(clientResume, err) + ? retry.generateCompanion(Flux.just(new RetrySignal(err))) : Mono.empty()) - .retryWhen( - errors -> - errors - .doOnNext( - retryErr -> - logger.debug("Resumption reconnection error", retryErr)) - .flatMap( - retryErr -> - Mono.from(reconnectOnError.apply(clientResume, retryErr)) - .doOnNext(v -> logger.debug("Retrying with: {}", v)))) + .retryWhen(retry) .timeout(resumeSessionDuration); }) .map(ClientServerInputMultiplexer::new) @@ -178,4 +167,28 @@ private static long remotePos(ByteBuf resumeOkFrame) { private static ConnectionErrorException errorFrameThrowable(long impliedPos) { return new ConnectionErrorException("resumption_server_pos=[" + impliedPos + "]"); } + + private static class RetrySignal implements Retry.RetrySignal { + + private final Throwable ex; + + RetrySignal(Throwable ex) { + this.ex = ex; + } + + @Override + public long totalRetries() { + return 0; + } + + @Override + public long totalRetriesInARow() { + return 0; + } + + @Override + public Throwable failure() { + return ex; + } + } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java index b46ac864b..461be02d2 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java @@ -21,7 +21,13 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; +/** + * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#backoff(long, Duration)} to {@link + * io.rsocket.core.Resume#retry(Retry)}. + */ +@Deprecated public class ExponentialBackoffResumeStrategy implements ResumeStrategy { private volatile Duration next; private final Duration firstBackoff; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java index abfefe0b1..bd447c8a9 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java @@ -19,7 +19,13 @@ import java.time.Duration; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; +/** + * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#fixedDelay(long, Duration)} to {@link + * io.rsocket.core.Resume#retry(Retry)}. + */ +@Deprecated public class PeriodicResumeStrategy implements ResumeStrategy { private final Duration interval; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java index 49401d560..980de2de1 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.resume; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; import io.rsocket.frame.FrameHeaderFlyweight; @@ -105,6 +106,11 @@ public ResumableDuplexConnection( reconnect(duplexConnection); } + @Override + public ByteBufAllocator alloc() { + return curConnection.alloc(); + } + public void disconnect() { DuplexConnection c = this.curConnection; if (c != null) { @@ -217,6 +223,10 @@ public boolean isDisposed() { } private void sendFrame(ByteBuf f) { + if (disposed.get()) { + f.release(); + return; + } /*resuming from store so no need to save again*/ if (state != State.RESUME && isResumableFrame(f)) { resumeSaveFrames.onNext(f); diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java index 903431192..d9dec9f54 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,12 @@ import java.util.function.BiFunction; import org.reactivestreams.Publisher; +import reactor.util.retry.Retry; +/** + * @deprecated as of 1.0 RC7 in favor of using {@link io.rsocket.core.Resume#retry(Retry)} via + * {@link io.rsocket.core.RSocketConnector} or {@link io.rsocket.core.RSocketServer}. + */ +@Deprecated @FunctionalInterface public interface ResumeStrategy extends BiFunction> {} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java index 1a0605497..5d55559cc 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java @@ -43,13 +43,12 @@ public class ServerRSocketSession implements RSocketSession { public ServerRSocketSession( DuplexConnection duplexConnection, - ByteBufAllocator allocator, Duration resumeSessionDuration, Duration resumeStreamTimeout, Function resumeStoreFactory, ByteBuf resumeToken, boolean cleanupStoreOnKeepAlive) { - this.allocator = allocator; + this.allocator = duplexConnection.alloc(); this.resumeToken = resumeToken; this.resumableConnection = new ResumableDuplexConnection( diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java b/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java deleted file mode 100644 index ec3d4ab3c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.Optional; -import java.util.ServiceLoader; - -/** Maps a {@link URI} to a {@link ClientTransport} or {@link ServerTransport}. */ -public interface UriHandler { - - /** - * Load all registered instances of {@code UriHandler}. - * - * @return all registered instances of {@code UriHandler} - */ - static ServiceLoader loadServices() { - return ServiceLoader.load(UriHandler.class); - } - - /** - * Returns an implementation of {@link ClientTransport} unambiguously mapped to a {@link URI}, - * otherwise {@link Optional#EMPTY}. - * - * @param uri the uri to map - * @return an implementation of {@link ClientTransport} unambiguously mapped to a {@link URI}, * - * otherwise {@link Optional#EMPTY} - * @throws NullPointerException if {@code uri} is {@code null} - */ - Optional buildClient(URI uri); - - /** - * Returns an implementation of {@link ServerTransport} unambiguously mapped to a {@link URI}, - * otherwise {@link Optional#EMPTY}. - * - * @param uri the uri to map - * @return an implementation of {@link ServerTransport} unambiguously mapped to a {@link URI}, * - * otherwise {@link Optional#EMPTY} - * @throws NullPointerException if {@code uri} is {@code null} - */ - Optional buildServer(URI uri); -} diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java b/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java deleted file mode 100644 index 204c5d1ea..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static io.rsocket.uri.UriHandler.loadServices; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.ServiceLoader; -import reactor.core.publisher.Mono; - -/** - * Registry for looking up transports by URI. - * - *

Uses the Jar Services mechanism with services defined by {@link UriHandler}. - */ -public class UriTransportRegistry { - private static final ClientTransport FAILED_CLIENT_LOOKUP = - (mtu) -> Mono.error(new UnsupportedOperationException()); - private static final ServerTransport FAILED_SERVER_LOOKUP = - (acceptor, mtu) -> Mono.error(new UnsupportedOperationException()); - - private List handlers; - - public UriTransportRegistry(ServiceLoader services) { - handlers = new ArrayList<>(); - services.forEach(handlers::add); - } - - public static UriTransportRegistry fromServices() { - ServiceLoader services = loadServices(); - - return new UriTransportRegistry(services); - } - - public static ClientTransport clientForUri(String uri) { - return UriTransportRegistry.fromServices().findClient(uri); - } - - public static ServerTransport serverForUri(String uri) { - return UriTransportRegistry.fromServices().findServer(uri); - } - - private ClientTransport findClient(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildClient(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_CLIENT_LOOKUP; - } - - private ServerTransport findServer(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildServer(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_SERVER_LOOKUP; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index b91cf8ac6..f5d747f7f 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -21,6 +21,7 @@ import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.Recycler; import io.netty.util.Recycler.Handle; import io.rsocket.Payload; @@ -112,9 +113,10 @@ public static Payload create(ByteBuf data) { public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { ByteBufPayload payload = RECYCLER.get(); - payload.setRefCnt(1); payload.data = data; payload.metadata = metadata; + // unsure data and metadata is set before refCnt change + payload.setRefCnt(1); return payload; } @@ -126,26 +128,31 @@ public static Payload create(Payload payload) { @Override public boolean hasMetadata() { + ensureAccessible(); return metadata != null; } @Override public ByteBuf sliceMetadata() { + ensureAccessible(); return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); } @Override public ByteBuf data() { + ensureAccessible(); return data; } @Override public ByteBuf metadata() { + ensureAccessible(); return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; } @Override public ByteBuf sliceData() { + ensureAccessible(); return data.slice(); } @@ -163,6 +170,7 @@ public ByteBufPayload retain(int increment) { @Override public ByteBufPayload touch() { + ensureAccessible(); data.touch(); if (metadata != null) { metadata.touch(); @@ -172,6 +180,7 @@ public ByteBufPayload touch() { @Override public ByteBufPayload touch(Object hint) { + ensureAccessible(); data.touch(hint); if (metadata != null) { metadata.touch(hint); @@ -189,4 +198,22 @@ protected void deallocate() { } handle.recycle(this); } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link ByteBufPayload#ensureAccessible()} to try to guard against using the + * buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java index e011d2a6f..328fb8435 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java @@ -99,7 +99,7 @@ private static char[] checkCharSequenceBounds(char[] seq, int start, int end) { } /** - * Encode a {@link char[]} in UTF-8 and write it + * Encode a {@code char[]} in UTF-8 and write it * into {@link ByteBuf}. * *

This method returns the actual number of bytes written. @@ -109,9 +109,8 @@ public static int writeUtf8(ByteBuf buf, char[] seq) { } /** - * Equivalent to {@link #writeUtf8(ByteBuf, char[]) - * writeUtf8(buf, seq.subSequence(start, end), reserveBytes)} but avoids subsequence object - * allocation if possible. + * Equivalent to {@link #writeUtf8(ByteBuf, char[]) writeUtf8(buf, seq.subSequence(start, end), + * reserveBytes)} but avoids subsequence object allocation if possible. * * @return actual number of bytes written */ @@ -182,7 +181,10 @@ public static char[] readUtf8(ByteBuf byteBuf, int length) { char[] ca = new char[en]; CharBuffer charBuffer = CharBuffer.wrap(ca); - ByteBuffer byteBuffer = byteBuf.internalNioBuffer(byteBuf.readerIndex(), length); + ByteBuffer byteBuffer = + byteBuf.nioBufferCount() == 1 + ? byteBuf.internalNioBuffer(byteBuf.readerIndex(), length) + : byteBuf.nioBuffer(byteBuf.readerIndex(), length); byteBuffer.mark(); try { CoderResult cr = charsetDecoder.decode(byteBuffer, charBuffer, true); diff --git a/rsocket-core/src/main/java/io/rsocket/util/ConnectionUtils.java b/rsocket-core/src/main/java/io/rsocket/util/ConnectionUtils.java deleted file mode 100644 index dd8bbf907..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/ConnectionUtils.java +++ /dev/null @@ -1,17 +0,0 @@ -package io.rsocket.util; - -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.internal.ClientServerInputMultiplexer; -import reactor.core.publisher.Mono; - -public class ConnectionUtils { - - public static Mono sendError( - ByteBufAllocator allocator, ClientServerInputMultiplexer multiplexer, Exception exception) { - return multiplexer - .asSetupConnection() - .sendOne(ErrorFrameFlyweight.encode(allocator, 0, exception)) - .onErrorResume(err -> Mono.empty()); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/DuplexConnectionProxy.java b/rsocket-core/src/main/java/io/rsocket/util/DuplexConnectionProxy.java index fa19553a7..2f5d1da4b 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DuplexConnectionProxy.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DuplexConnectionProxy.java @@ -17,6 +17,7 @@ package io.rsocket.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -44,6 +45,11 @@ public double availability() { return connection.availability(); } + @Override + public ByteBufAllocator alloc() { + return connection.alloc(); + } + @Override public Mono onClose() { return connection.onClose(); diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java deleted file mode 100644 index a739f2e67..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static io.rsocket.frame.FrameType.*; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.*; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import io.rsocket.util.MultiSubscriberRSocket; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -import org.assertj.core.api.Assertions; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.publisher.BaseSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.UnicastProcessor; - -public class RSocketRequesterTest { - - @Rule public final ClientSocketRule rule = new ClientSocketRule(); - - @Test(timeout = 2_000) - public void testInvalidFrameOnStream0() { - rule.connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 0, 10)); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(IllegalStateException.class))); - } - - @Test(timeout = 2_000) - public void testStreamInitialN() { - Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); - - BaseSubscriber subscriber = - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - // don't request here - // subscription.request(3); - } - }; - stream.subscribe(subscriber); - - subscriber.request(5); - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat("sent frame count", sent.size(), is(1)); - - ByteBuf f = sent.get(0); - - assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); - assertThat("initial request n", RequestStreamFrameFlyweight.initialRequestN(f), is(5)); - } - - @Test(timeout = 2_000) - public void testHandleSetupException() { - rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("boom"))); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(RejectedSetupException.class))); - } - - @Test(timeout = 2_000) - public void testHandleApplicationException() { - rule.connection.clearSendReceiveBuffers(); - Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber responseSub = TestSubscriber.create(); - response.subscribe(responseSub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, new ApplicationErrorException("error"))); - - verify(responseSub).onError(any(ApplicationErrorException.class)); - } - - @Test(timeout = 2_000) - public void testHandleValidFrame() { - Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNext( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); - - verify(sub).onComplete(); - } - - @Test(timeout = 2_000) - public void testRequestReplyWithCancel() { - Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - - try { - response.block(Duration.ofMillis(100)); - } catch (IllegalStateException ise) { - } - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat( - "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); - assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); - } - - @Test(timeout = 2_000) - public void testRequestReplyErrorOnSend() { - rule.connection.setAvailability(0); // Fails send - Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber responseSub = TestSubscriber.create(10); - response.subscribe(responseSub); - - this.rule.assertNoConnectionErrors(); - - verify(responseSub).onSubscribe(any(Subscription.class)); - - // TODO this should get the error reported through the response subscription - // verify(responseSub).onError(any(RuntimeException.class)); - } - - @Test(timeout = 2_000) - public void testLazyRequestResponse() { - Publisher response = - new MultiSubscriberRSocket(rule.socket).requestResponse(EmptyPayload.INSTANCE); - int streamId = sendRequestResponse(response); - rule.connection.clearSendReceiveBuffers(); - int streamId2 = sendRequestResponse(response); - assertThat("Stream ID reused.", streamId2, not(equalTo(streamId))); - } - - @Test - public void testChannelRequestCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - Flux request = Flux.never().doOnCancel(cancelled::onComplete); - rule.socket.requestChannel(request).subscribe().dispose(); - Flux.first( - cancelled, - Flux.error(new IllegalStateException("Channel request not cancelled")) - .delaySubscription(Duration.ofSeconds(1))) - .blockFirst(); - } - - @Test - public void testChannelRequestServerSideCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - UnicastProcessor request = UnicastProcessor.create(); - request.onNext(EmptyPayload.INSTANCE); - rule.socket.requestChannel(request).subscribe(cancelled); - int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); - rule.connection.addToReceivedBuffer( - CancelFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId)); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeComplete(ByteBufAllocator.DEFAULT, streamId)); - Flux.first( - cancelled, - Flux.error(new IllegalStateException("Channel request not cancelled")) - .delaySubscription(Duration.ofSeconds(1))) - .blockFirst(); - - Assertions.assertThat(request.isDisposed()).isTrue(); - } - - public int sendRequestResponse(Publisher response) { - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNextComplete( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); - verify(sub).onNext(any(Payload.class)); - verify(sub).onComplete(); - return streamId; - } - - public static class ClientSocketRule extends AbstractSocketRule { - @Override - protected RSocketRequester newRSocket() { - return new RSocketRequester( - ByteBufAllocator.DEFAULT, - connection, - DefaultPayload::create, - throwable -> errors.add(throwable), - StreamIdSupplier.clientSupplier(), - 0, - 0, - null, - RequesterLeaseHandler.None); - } - - public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); - List framesFound = new ArrayList<>(); - for (ByteBuf frame : connection.getSent()) { - FrameType frameType = frameType(frame); - if (frameType == expectedFrameType) { - return FrameHeaderFlyweight.streamId(frame); - } - framesFound.add(frameType); - } - throw new AssertionError( - "No frames sent with frame type: " - + expectedFrameType - + ", frames found: " - + framesFound); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java deleted file mode 100644 index b6281414d..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.frame.*; -import io.rsocket.lease.ResponderLeaseHandler; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import java.util.Collection; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.Mono; - -public class RSocketResponderTest { - - @Rule public final ServerSocketRule rule = new ServerSocketRule(); - - @Test(timeout = 2000) - @Ignore - public void testHandleKeepAlive() throws Exception { - rule.connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode(ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER)); - ByteBuf sent = rule.connection.awaitSend(); - assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); - /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ - assertThat( - "Unexpected keep-alive frame respond flag.", - KeepAliveFrameFlyweight.respondFlag(sent), - is(false)); - } - - @Test(timeout = 2000) - @Ignore - public void testHandleResponseFrameNoError() throws Exception { - final int streamId = 4; - rule.connection.clearSendReceiveBuffers(); - - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - Collection> sendSubscribers = rule.connection.getSendSubscribers(); - assertThat("Request not sent.", sendSubscribers, hasSize(1)); - assertThat("Unexpected error.", rule.errors, is(empty())); - Subscriber sendSub = sendSubscribers.iterator().next(); - assertThat( - "Unexpected frame sent.", - frameType(rule.connection.awaitSend()), - anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); - } - - @Test(timeout = 2000) - @Ignore - public void testHandlerEmitsError() throws Exception { - final int streamId = 4; - rule.sendRequest(streamId, FrameType.REQUEST_STREAM); - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat( - "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); - } - - @Test(timeout = 2_0000) - public void testCancel() { - final int streamId = 4; - final AtomicBoolean cancelled = new AtomicBoolean(); - rule.setAcceptingSocket( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.never().doOnCancel(() -> cancelled.set(true)); - } - }); - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - - rule.connection.addToReceivedBuffer( - CancelFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId)); - - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); - } - - public static class ServerSocketRule extends AbstractSocketRule { - - private RSocket acceptingSocket; - - @Override - protected void init() { - acceptingSocket = - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - }; - super.init(); - } - - public void setAcceptingSocket(RSocket acceptingSocket) { - this.acceptingSocket = acceptingSocket; - connection = new TestDuplexConnection(); - connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); - super.init(); - } - - public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { - this.acceptingSocket = acceptingSocket; - connection = new TestDuplexConnection(); - connection.setInitialSendRequestN(prefetch); - connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); - super.init(); - } - - @Override - protected RSocketResponder newRSocket() { - return new RSocketResponder( - ByteBufAllocator.DEFAULT, - connection, - acceptingSocket, - DefaultPayload::create, - throwable -> errors.add(throwable), - ResponderLeaseHandler.None); - } - - private void sendRequest(int streamId, FrameType frameType) { - ByteBuf request; - - switch (frameType) { - case REQUEST_CHANNEL: - request = - RequestChannelFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, false, 1, EmptyPayload.INSTANCE); - break; - case REQUEST_STREAM: - request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, 1, EmptyPayload.INSTANCE); - break; - case REQUEST_RESPONSE: - request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, EmptyPayload.INSTANCE); - break; - default: - throw new IllegalArgumentException("unsupported type: " + frameType); - } - - connection.addToReceivedBuffer(request); - connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, 2)); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java deleted file mode 100644 index 80865ec47..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.exceptions.CustomRSocketException; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; -import io.rsocket.test.util.LocalDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import java.util.ArrayList; -import org.hamcrest.MatcherAssert; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; -import org.mockito.ArgumentCaptor; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.DirectProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -public class RSocketTest { - - @Rule public final SocketRule rule = new SocketRule(); - - public static void assertError(String s, String mode, ArrayList errors) { - for (Throwable t : errors) { - if (t.toString().equals(s)) { - return; - } - } - - Assert.fail("Expected " + mode + " connection error: " + s + " other errors " + errors.size()); - } - - @Test(timeout = 2_000) - public void testRequestReplyNoError() { - StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) - .expectNextCount(1) - .expectComplete() - .verify(); - } - - @Test(timeout = 2000) - public void testHandlerEmitsError() { - rule.setRequestAcceptor( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.error(new NullPointerException("Deliberate exception.")); - } - }); - Subscriber subscriber = TestSubscriber.create(); - rule.crs.requestResponse(EmptyPayload.INSTANCE).subscribe(subscriber); - verify(subscriber).onError(any(ApplicationErrorException.class)); - - // Client sees error through normal API - rule.assertNoClientErrors(); - - rule.assertServerError("java.lang.NullPointerException: Deliberate exception."); - } - - @Test(timeout = 2000) - public void testHandlerEmitsCustomError() { - rule.setRequestAcceptor( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.error( - new CustomRSocketException(0x00000501, "Deliberate Custom exception.")); - } - }); - Subscriber subscriber = TestSubscriber.create(); - rule.crs.requestResponse(EmptyPayload.INSTANCE).subscribe(subscriber); - ArgumentCaptor customRSocketExceptionArgumentCaptor = - ArgumentCaptor.forClass(CustomRSocketException.class); - verify(subscriber).onError(customRSocketExceptionArgumentCaptor.capture()); - - Assert.assertEquals( - "Deliberate Custom exception.", - customRSocketExceptionArgumentCaptor.getValue().getMessage()); - Assert.assertEquals(0x00000501, customRSocketExceptionArgumentCaptor.getValue().errorCode()); - - // Client sees error through normal API - rule.assertNoClientErrors(); - - rule.assertServerError( - "io.rsocket.exceptions.CustomRSocketException: Deliberate Custom exception."); - } - - @Test(timeout = 2000) - public void testStream() throws Exception { - Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); - StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); - } - - @Test(timeout = 2000) - public void testChannel() throws Exception { - Flux requests = - Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); - Flux responses = rule.crs.requestChannel(requests); - StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); - } - - public static class SocketRule extends ExternalResource { - - DirectProcessor serverProcessor; - DirectProcessor clientProcessor; - private RSocketRequester crs; - - @SuppressWarnings("unused") - private RSocketResponder srs; - - private RSocket requestAcceptor; - private ArrayList clientErrors = new ArrayList<>(); - private ArrayList serverErrors = new ArrayList<>(); - - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - init(); - base.evaluate(); - } - }; - } - - protected void init() { - serverProcessor = DirectProcessor.create(); - clientProcessor = DirectProcessor.create(); - - LocalDuplexConnection serverConnection = - new LocalDuplexConnection("server", clientProcessor, serverProcessor); - LocalDuplexConnection clientConnection = - new LocalDuplexConnection("client", serverProcessor, clientProcessor); - - requestAcceptor = - null != requestAcceptor - ? requestAcceptor - : new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 10) - .map( - i -> DefaultPayload.create("server got -> [" + payload.toString() + "]")); - } - - @Override - public Flux requestChannel(Publisher payloads) { - Flux.from(payloads) - .map( - payload -> - DefaultPayload.create("server got -> [" + payload.toString() + "]")) - .subscribe(); - - return Flux.range(1, 10) - .map( - payload -> - DefaultPayload.create("server got -> [" + payload.toString() + "]")); - } - }; - - srs = - new RSocketResponder( - ByteBufAllocator.DEFAULT, - serverConnection, - requestAcceptor, - DefaultPayload::create, - throwable -> serverErrors.add(throwable), - ResponderLeaseHandler.None); - - crs = - new RSocketRequester( - ByteBufAllocator.DEFAULT, - clientConnection, - DefaultPayload::create, - throwable -> clientErrors.add(throwable), - StreamIdSupplier.clientSupplier(), - 0, - 0, - null, - RequesterLeaseHandler.None); - } - - public void setRequestAcceptor(RSocket requestAcceptor) { - this.requestAcceptor = requestAcceptor; - init(); - } - - public void assertNoErrors() { - assertNoClientErrors(); - assertNoServerErrors(); - } - - public void assertNoClientErrors() { - MatcherAssert.assertThat( - "Unexpected error on the client connection.", clientErrors, is(empty())); - } - - public void assertNoServerErrors() { - MatcherAssert.assertThat( - "Unexpected error on the server connection.", serverErrors, is(empty())); - } - - public void assertClientError(String s) { - assertError(s, "client", this.clientErrors); - } - - public void assertServerError(String s) { - assertError(s, "server", this.serverErrors); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..800e5d678 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,153 @@ +package io.rsocket.buffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) { + this.delegate = delegate; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + Assertions.assertThat(tracker) + .allSatisfy( + buf -> + Assertions.assertThat(buf) + .matches(bb -> bb.refCnt() == 0, "buffer should be released")); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java similarity index 78% rename from rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java rename to rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java index 22568bfcc..20972a0d3 100644 --- a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -14,8 +14,11 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; import java.util.concurrent.ConcurrentLinkedQueue; @@ -31,13 +34,15 @@ public abstract class AbstractSocketRule extends ExternalReso protected Subscriber connectSub; protected T socket; protected ConcurrentLinkedQueue errors; + protected LeaksTrackingByteBufAllocator allocator; @Override public Statement apply(final Statement base, Description description) { return new Statement() { @Override public void evaluate() throws Throwable { - connection = new TestDuplexConnection(); + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); connectSub = TestSubscriber.create(); errors = new ConcurrentLinkedQueue<>(); init(); @@ -57,4 +62,12 @@ public void assertNoConnectionErrors() { Assert.fail("No connection errors expected: " + errors.peek().toString()); } } + + public ByteBufAllocator alloc() { + return allocator; + } + + public void assertHasNoLeaks() { + allocator.assertHasNoLeaks(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/ConnectionSetupPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java similarity index 89% rename from rsocket-core/src/test/java/io/rsocket/ConnectionSetupPayloadTest.java rename to rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java index 16e0f2ec7..ea3142d25 100644 --- a/rsocket-core/src/test/java/io/rsocket/ConnectionSetupPayloadTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java @@ -1,4 +1,4 @@ -package io.rsocket; +package io.rsocket.core; import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -6,6 +6,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; import io.rsocket.frame.SetupFrameFlyweight; import io.rsocket.util.DefaultPayload; import org.junit.jupiter.api.Test; @@ -24,7 +26,7 @@ void testSetupPayloadWithDataMetadata() { boolean leaseEnabled = true; ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(frame); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); assertTrue(setupPayload.willClientHonorLease()); assertEquals(KEEP_ALIVE_INTERVAL, setupPayload.keepAliveInterval()); @@ -46,7 +48,7 @@ void testSetupPayloadWithNoMetadata() { boolean leaseEnabled = false; ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(frame); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); assertFalse(setupPayload.willClientHonorLease()); assertFalse(setupPayload.hasMetadata()); @@ -64,7 +66,7 @@ void testSetupPayloadWithEmptyMetadata() { boolean leaseEnabled = false; ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(frame); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); assertFalse(setupPayload.willClientHonorLease()); assertTrue(setupPayload.hasMetadata()); diff --git a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java similarity index 88% rename from rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java rename to rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index b275ccc33..e8f3f4190 100644 --- a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import static io.rsocket.keepalive.KeepAliveHandler.DefaultKeepAliveHandler; import static io.rsocket.keepalive.KeepAliveHandler.ResumableKeepAliveHandler; @@ -22,6 +22,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameType; @@ -51,24 +53,28 @@ public class KeepAliveTest { private ResumableRSocketState resumableRequesterState; static RSocketState requester(int tickPeriod, int timeout) { - TestDuplexConnection connection = new TestDuplexConnection(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); Errors errors = new Errors(); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, connection, DefaultPayload::create, errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new DefaultKeepAliveHandler(connection), RequesterLeaseHandler.None); - return new RSocketState(rSocket, errors, connection); + return new RSocketState(rSocket, errors, allocator, connection); } static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { - TestDuplexConnection connection = new TestDuplexConnection(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); ResumableDuplexConnection resumableConnection = new ResumableDuplexConnection( "test", @@ -80,16 +86,16 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { Errors errors = new Errors(); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, resumableConnection, DefaultPayload::create, errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new ResumableKeepAliveHandler(resumableConnection), RequesterLeaseHandler.None); - return new ResumableRSocketState(rSocket, errors, connection, resumableConnection); + return new ResumableRSocketState(rSocket, errors, connection, resumableConnection, allocator); } @BeforeEach @@ -191,7 +197,7 @@ void resumableRequesterKeepAlivesAfterReconnect() { resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); resumableDuplexConnection.disconnect(); - TestDuplexConnection newTestConnection = new TestDuplexConnection(); + TestDuplexConnection newTestConnection = new TestDuplexConnection(rSocketState.alloc()); resumableDuplexConnection.reconnect(newTestConnection); resumableDuplexConnection.resume(0, 0, ignored -> Mono.empty()); @@ -241,11 +247,17 @@ static class RSocketState { private final RSocket rSocket; private final Errors errors; private final TestDuplexConnection connection; + private final LeaksTrackingByteBufAllocator allocator; - public RSocketState(RSocket rSocket, Errors errors, TestDuplexConnection connection) { + public RSocketState( + RSocket rSocket, + Errors errors, + LeaksTrackingByteBufAllocator allocator, + TestDuplexConnection connection) { this.rSocket = rSocket; this.errors = errors; this.connection = connection; + this.allocator = allocator; } public TestDuplexConnection connection() { @@ -259,6 +271,10 @@ public RSocket rSocket() { public Errors errors() { return errors; } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } } static class ResumableRSocketState { @@ -266,16 +282,19 @@ static class ResumableRSocketState { private final Errors errors; private final TestDuplexConnection connection; private final ResumableDuplexConnection resumableDuplexConnection; + private final LeaksTrackingByteBufAllocator allocator; public ResumableRSocketState( RSocket rSocket, Errors errors, TestDuplexConnection connection, - ResumableDuplexConnection resumableDuplexConnection) { + ResumableDuplexConnection resumableDuplexConnection, + LeaksTrackingByteBufAllocator allocator) { this.rSocket = rSocket; this.errors = errors; this.connection = connection; this.resumableDuplexConnection = resumableDuplexConnection; + this.allocator = allocator; } public TestDuplexConnection connection() { @@ -293,6 +312,10 @@ public RSocket rSocket() { public Errors errors() { return errors; } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } } static class Errors implements Consumer { diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java new file mode 100644 index 000000000..e91fce848 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -0,0 +1,99 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class PayloadValidationUtilsTest { + + @Test + void shouldBeValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size() + + 1]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation0() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2]; + byte[] data = + new byte + [FrameLengthFlyweight.FRAME_LENGTH_MASK / 2 + - FrameLengthFlyweight.FRAME_LENGTH_SIZE + - FrameHeaderFlyweight.size() + - FrameHeaderFlyweight.size()]; + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation1() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation2() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation3() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation4() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java similarity index 92% rename from rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java rename to rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 3af8916cd..51f5afc24 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,22 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import static io.rsocket.frame.FrameType.ERROR; import static io.rsocket.frame.FrameType.SETUP; import static org.assertj.core.data.Offset.offset; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; -import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; -import io.rsocket.exceptions.MissingLeaseException; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameType; import io.rsocket.frame.LeaseFrameFlyweight; @@ -34,7 +37,8 @@ import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.lease.*; -import io.rsocket.plugins.PluginRegistry; +import io.rsocket.lease.MissingLeaseException; +import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.test.util.TestClientTransport; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestServerTransport; @@ -73,26 +77,26 @@ class RSocketLeaseTest { @BeforeEach void setUp() { - connection = new TestDuplexConnection(); PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; - byteBufAllocator = UnpooledByteBufAllocator.DEFAULT; + byteBufAllocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(byteBufAllocator); requesterLeaseHandler = new RequesterLeaseHandler.Impl(TAG, leases -> leaseReceiver = leases); responderLeaseHandler = new ResponderLeaseHandler.Impl<>( TAG, byteBufAllocator, stats -> leaseSender, err -> {}, Optional.empty()); ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, new PluginRegistry(), true); + new ClientServerInputMultiplexer(connection, new InitializingInterceptorRegistry(), true); rSocketRequester = new RSocketRequester( - byteBufAllocator, multiplexer.asClientConnection(), payloadDecoder, err -> {}, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, requesterLeaseHandler); @@ -105,12 +109,12 @@ void setUp() { rSocketResponder = new RSocketResponder( - byteBufAllocator, multiplexer.asServerConnection(), mockRSocketHandler, payloadDecoder, err -> {}, - responderLeaseHandler); + responderLeaseHandler, + 0); } @Test @@ -127,12 +131,7 @@ public void serverRSocketFactoryRejectsUnsupportedLease() { payload); TestServerTransport transport = new TestServerTransport(); - Closeable server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new AbstractRSocket() {})) - .transport(transport) - .start() - .block(); + RSocketServer.create().bind(transport).block(); TestDuplexConnection connection = transport.connect(); connection.addToReceivedBuffer(setupFrame); @@ -148,7 +147,7 @@ public void serverRSocketFactoryRejectsUnsupportedLease() { @Test public void clientRSocketFactorySetsLeaseFlag() { TestClientTransport clientTransport = new TestClientTransport(); - RSocketFactory.connect().lease().transport(clientTransport).start().block(); + RSocketConnector.create().lease(Leases::new).connect(clientTransport).block(); Collection sent = clientTransport.testConnection().getSent(); Assertions.assertThat(sent).hasSize(1); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java new file mode 100644 index 000000000..dc76b5450 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static org.junit.Assert.assertEquals; + +import io.rsocket.RSocket; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.transport.ClientTransport; +import java.io.UncheckedIOException; +import java.time.Duration; +import java.util.Iterator; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.Exceptions; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketReconnectTest { + + private Queue retries = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldBeASharedReconnectableInstanceOfRSocketMono() { + TestClientTransport[] testClientTransport = + new TestClientTransport[] {new TestClientTransport()}; + Mono rSocketMono = + RSocketConnector.create() + .reconnect(Retry.indefinitely()) + .connect(() -> testClientTransport[0]); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + testClientTransport[0] = new TestClientTransport(); + + RSocket rSocket3 = rSocketMono.block(); + RSocket rSocket4 = rSocketMono.block(); + + Assertions.assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + } + + @Test + @SuppressWarnings({"rawtype", "unchecked"}) + public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + Mockito.when(transport.connect(0)) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(new TestClientTransport().connect(0)); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + } + + @Test + @SuppressWarnings({"rawtype", "unchecked"}) + public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + Mockito.when(transport.connect(0)) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(new TestClientTransport().connect(0)); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + } + + @Test + public void shouldBeNotBeASharedReconnectableInstanceOfRSocketMono() { + + Mono rSocketMono = RSocketConnector.connectWith(new TestClientTransport()); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isNotEqualTo(rSocket2); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertEquals(exceptions.length, retries.size()); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertEquals(index, retryContext.totalRetries()); + assertEquals(exceptions[index], retryContext.failure().getClass()); + index++; + } + } + + Consumer onRetry() { + return context -> retries.add(context); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java similarity index 77% rename from rsocket-core/src/test/java/io/rsocket/RSocketRequesterSubscribersTest.java rename to rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index b49dbe809..01cf99e26 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -14,10 +14,12 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; @@ -38,7 +40,6 @@ import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import reactor.test.StepVerifier; class RSocketRequesterSubscribersTest { @@ -52,21 +53,23 @@ class RSocketRequesterSubscribersTest { FrameType.REQUEST_STREAM, FrameType.REQUEST_CHANNEL)); + private LeaksTrackingByteBufAllocator allocator; private RSocket rSocketRequester; private TestDuplexConnection connection; @BeforeEach void setUp() { - connection = new TestDuplexConnection(); + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); rSocketRequester = new RSocketRequester( - ByteBufAllocator.DEFAULT, connection, PayloadDecoder.DEFAULT, err -> {}, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); } @@ -75,9 +78,15 @@ void setUp() { @MethodSource("allInteractions") void multiSubscriber(Function> interaction) { RSocket multiSubsRSocket = new MultiSubscriberRSocket(rSocketRequester); - Flux response = Flux.from(interaction.apply(multiSubsRSocket)).take(Duration.ofMillis(10)); - StepVerifier.create(response).expectComplete().verify(Duration.ofSeconds(5)); - StepVerifier.create(response).expectComplete().verify(Duration.ofSeconds(5)); + Flux response = Flux.from(interaction.apply(multiSubsRSocket)); + StepVerifier.withVirtualTime(() -> response.take(Duration.ofMillis(10))) + .thenAwait(Duration.ofMillis(10)) + .expectComplete() + .verify(Duration.ofSeconds(5)); + StepVerifier.withVirtualTime(() -> response.take(Duration.ofMillis(10))) + .thenAwait(Duration.ofMillis(10)) + .expectComplete() + .verify(Duration.ofSeconds(5)); Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(2); } @@ -85,9 +94,13 @@ void multiSubscriber(Function> interaction) { @ParameterizedTest @MethodSource("allInteractions") void singleSubscriber(Function> interaction) { - Flux response = Flux.from(interaction.apply(rSocketRequester)).take(Duration.ofMillis(10)); - StepVerifier.create(response).expectComplete().verify(Duration.ofSeconds(5)); - StepVerifier.create(response) + Flux response = Flux.from(interaction.apply(rSocketRequester)); + StepVerifier.withVirtualTime(() -> response.take(Duration.ofMillis(10))) + .thenAwait(Duration.ofMillis(10)) + .expectComplete() + .verify(Duration.ofSeconds(5)); + StepVerifier.withVirtualTime(() -> response.take(Duration.ofMillis(10))) + .thenAwait(Duration.ofMillis(10)) .expectError(IllegalStateException.class) .verify(Duration.ofSeconds(5)); @@ -114,7 +127,7 @@ static Stream>> allInteractions() { rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), rSocket -> rSocket.requestStream(DefaultPayload.create("test")), - rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), + // rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), rSocket -> rSocket.metadataPush(DefaultPayload.create("test"))); } } diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java similarity index 93% rename from rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java rename to rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java index a2c17cf95..de6f86c57 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java @@ -1,6 +1,8 @@ -package io.rsocket; +package io.rsocket.core; -import io.rsocket.RSocketRequesterTest.ClientSocketRule; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketRequesterTest.ClientSocketRule; import io.rsocket.util.EmptyPayload; import java.nio.channels.ClosedChannelException; import java.time.Duration; diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java new file mode 100644 index 000000000..3b62bc437 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -0,0 +1,836 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameHeaderFlyweight.frameType; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.CancelFrameFlyweight; +import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestNFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import io.rsocket.util.MultiSubscriberRSocket; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.UnicastProcessor; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketRequesterTest { + + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); + rule = new ClientSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + } + + @Test + @Timeout(2_000) + public void testInvalidFrameOnStream0() { + rule.connection.addToReceivedBuffer(RequestNFrameFlyweight.encode(rule.alloc(), 0, 10)); + assertThat("Unexpected errors.", rule.errors, hasSize(1)); + assertThat( + "Unexpected error received.", + rule.errors, + contains(instanceOf(IllegalStateException.class))); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testStreamInitialN() { + Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); + + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + // don't request here + } + }; + stream.subscribe(subscriber); + + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + subscriber.request(5); + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat("sent frame count", sent.size(), is(1)); + + ByteBuf f = sent.get(0); + + assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); + assertThat("initial request n", RequestStreamFrameFlyweight.initialRequestN(f), is(5L)); + assertThat("should be released", f.release(), is(true)); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleSetupException() { + rule.connection.addToReceivedBuffer( + ErrorFrameFlyweight.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); + assertThat("Unexpected errors.", rule.errors, hasSize(1)); + assertThat( + "Unexpected error received.", + rule.errors, + contains(instanceOf(RejectedSetupException.class))); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleApplicationException() { + rule.connection.clearSendReceiveBuffers(); + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(); + response.subscribe(responseSub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + ErrorFrameFlyweight.encode(rule.alloc(), streamId, new ApplicationErrorException("error"))); + + verify(responseSub).onError(any(ApplicationErrorException.class)); + + Assertions.assertThat(rule.connection.getSent()) + // requestResponseFrame + .hasSize(1) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleValidFrame() { + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameFlyweight.encodeNextReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + + verify(sub).onComplete(); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testRequestReplyWithCancel() { + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + + try { + response.block(Duration.ofMillis(100)); + } catch (IllegalStateException ise) { + } + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat( + "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); + assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); + Assertions.assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Disabled("invalid") + @Timeout(2_000) + public void testRequestReplyErrorOnSend() { + rule.connection.setAvailability(0); // Fails send + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(10); + response.subscribe(responseSub); + + this.rule.assertNoConnectionErrors(); + + verify(responseSub).onSubscribe(any(Subscription.class)); + + rule.assertHasNoLeaks(); + // TODO this should get the error reported through the response subscription + // verify(responseSub).onError(any(RuntimeException.class)); + } + + @Test + @Timeout(2_000) + public void testLazyRequestResponse() { + Publisher response = + new MultiSubscriberRSocket(rule.socket).requestResponse(EmptyPayload.INSTANCE); + int streamId = sendRequestResponse(response); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + rule.connection.clearSendReceiveBuffers(); + int streamId2 = sendRequestResponse(response); + assertThat("Stream ID reused.", streamId2, not(equalTo(streamId))); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation() { + MonoProcessor cancelled = MonoProcessor.create(); + Flux request = Flux.never().doOnCancel(cancelled::onComplete); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation2() { + MonoProcessor cancelled = MonoProcessor.create(); + Flux request = + Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::onComplete); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testChannelRequestServerSideCancellation() { + MonoProcessor cancelled = MonoProcessor.create(); + UnicastProcessor request = UnicastProcessor.create(); + request.onNext(EmptyPayload.INSTANCE); + rule.socket.requestChannel(request).subscribe(cancelled); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + rule.connection.addToReceivedBuffer(CancelFrameFlyweight.encode(rule.alloc(), streamId)); + rule.connection.addToReceivedBuffer( + PayloadFrameFlyweight.encodeComplete(rule.alloc(), streamId)); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + + Assertions.assertThat(request.isDisposed()).isTrue(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_CHANNEL) + .matches(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testCorrectFrameOrder() { + MonoProcessor delayer = MonoProcessor.create(); + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) {} + }; + rule.socket + .requestChannel( + Flux.concat(Flux.just(0).delayUntil(i -> delayer), Flux.range(1, 999)) + .map(i -> DefaultPayload.create(i + ""))) + .subscribe(subscriber); + + subscriber.request(1); + subscriber.request(Long.MAX_VALUE); + delayer.onComplete(); + + Iterator iterator = rule.connection.getSent().iterator(); + + ByteBuf initialFrame = iterator.next(); + + Assertions.assertThat(FrameHeaderFlyweight.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); + Assertions.assertThat(RequestChannelFrameFlyweight.initialRequestN(initialFrame)) + .isEqualTo(Long.MAX_VALUE); + Assertions.assertThat( + RequestChannelFrameFlyweight.data(initialFrame).toString(CharsetUtil.UTF_8)) + .isEqualTo("0"); + Assertions.assertThat(initialFrame.release()).isTrue(); + + Assertions.assertThat(iterator.hasNext()).isFalse(); + rule.assertHasNoLeaks(); + } + + @Test + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + generator.apply(rule.socket, DefaultPayload.create(data, metadata))) + .expectSubscription() + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + rule.assertHasNoLeaks(); + }); + } + + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + + @Test + public void + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + rule.socket.requestChannel( + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata)))) + .expectSubscription() + .then( + () -> + rule.connection.addToReceivedBuffer( + RequestNFrameFlyweight.encode( + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + Assertions.assertThat(rule.connection.getSent()) + // expect to be sent RequestChannelFrame + // expect to be sent CancelFrame + .hasSize(2) + .allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("racingCases") + public void checkNoLeaksOnRacing( + Function> initiator, + BiConsumer, ClientSocketRule> runner) { + for (int i = 0; i < 10000; i++) { + ClientSocketRule clientSocketRule = new ClientSocketRule(); + try { + clientSocketRule + .apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } catch (Throwable throwable) { + throwable.printStackTrace(); + } + + Publisher payloadP = initiator.apply(clientSocketRule); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + if (payloadP instanceof Flux) { + ((Flux) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } else { + ((Mono) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } + + runner.accept(assertSubscriber, clientSocketRule); + + Assertions.assertThat(clientSocketRule.connection.getSent()) + .allMatch(ReferenceCounted::release); + + clientSocketRule.assertHasNoLeaks(); + } + } + + private static Stream racingCases() { + return Stream.of( + Arguments.of( + (Function>) + (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameFlyweight.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + PayloadFrameFlyweight.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + + return rule.socket.requestStream(payload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_STREAM, + "Expected first frame matches {" + + REQUEST_STREAM + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + return rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + sink.complete(); + return ++index; + })); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_CHANNEL, + "Expected first frame matches {" + + REQUEST_CHANNEL + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + ErrorFrameFlyweight.encode(allocator, streamId, new RuntimeException("test")); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + ByteBuf frame = + PayloadFrameFlyweight.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + })); + } + + @Test + public void simpleOnDiscardRequestChannelTest() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next( + ByteBufPayload.create("d", "m"), + ByteBufPayload.create("d1", "m1"), + ByteBufPayload.create("d2", "m2")); + + assertSubscriber.cancel(); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleOnDiscardRequestChannelTest2() { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next(ByteBufPayload.create("d", "m")); + + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + testPublisher.next(ByteBufPayload.create("d1", "m1"), ByteBufPayload.create("d2", "m2")); + + rule.connection.addToReceivedBuffer( + ErrorFrameFlyweight.encode( + allocator, streamId, new CustomRSocketException(0x00000404, "test"))); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(responsesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + Publisher response; + + switch (frameType) { + case REQUEST_FNF: + response = + testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p).then(Mono.empty())); + break; + case REQUEST_RESPONSE: + response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p)); + break; + case REQUEST_STREAM: + response = testPublisher.mono().flatMapMany(p -> rule.socket.requestStream(p)); + break; + case REQUEST_CHANNEL: + response = rule.socket.requestChannel(testPublisher.flux()); + break; + default: + throw new UnsupportedOperationException("illegal case"); + } + + response.subscribe(assertSubscriber); + testPublisher.next(ByteBufPayload.create("d")); + + int streamId = rule.getStreamIdForRequestType(frameType); + + if (responsesCnt > 0) { + for (int i = 0; i < responsesCnt - 1; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameFlyweight.encode( + allocator, + streamId, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("rd" + (i + 1)).getBytes()))); + } + + rule.connection.addToReceivedBuffer( + PayloadFrameFlyweight.encode( + allocator, + streamId, + false, + true, + true, + null, + Unpooled.wrappedBuffer(("rd" + responsesCnt).getBytes()))); + } + + if (framesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameFlyweight.encode(allocator, streamId, framesCnt)); + } + + for (int i = 1; i < framesCnt; i++) { + testPublisher.next(ByteBufPayload.create("d" + i)); + } + + Assertions.assertThat(rule.connection.getSent()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt) + .hasSize(framesCnt) + .allMatch(bb -> !FrameHeaderFlyweight.hasMetadata(bb)) + .allMatch(ByteBuf::release); + + Assertions.assertThat(assertSubscriber.isTerminated()) + .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType) + .isTrue(); + + Assertions.assertThat(assertSubscriber.values()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames received", + frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(p -> p.release()); + + rule.assertHasNoLeaks(); + rule.connection.clearSendReceiveBuffers(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + public int sendRequestResponse(Publisher response) { + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameFlyweight.encodeNextCompleteReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + verify(sub).onNext(any(Payload.class)); + verify(sub).onComplete(); + return streamId; + } + + public static class ClientSocketRule extends AbstractSocketRule { + @Override + protected RSocketRequester newRSocket() { + return new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + throwable -> errors.add(throwable), + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None); + } + + public int getStreamIdForRequestType(FrameType expectedFrameType) { + assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); + List framesFound = new ArrayList<>(); + for (ByteBuf frame : connection.getSent()) { + FrameType frameType = frameType(frame); + if (frameType == expectedFrameType) { + return FrameHeaderFlyweight.streamId(frame); + } + framesFound.add(frameType); + } + throw new AssertionError( + "No frames sent with frame type: " + + expectedFrameType + + ", frames found: " + + framesFound); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java new file mode 100644 index 000000000..c19456548 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -0,0 +1,795 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameHeaderFlyweight.frameType; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.AbstractRSocket; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.CancelFrameFlyweight; +import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameFlyweight; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; +import io.rsocket.frame.RequestNFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.util.Collection; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketResponderTest { + + ServerSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped(t -> {}); + rule = new ServerSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleKeepAlive() throws Exception { + rule.connection.addToReceivedBuffer( + KeepAliveFrameFlyweight.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); + ByteBuf sent = rule.connection.awaitSend(); + assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); + /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ + assertThat( + "Unexpected keep-alive frame respond flag.", + KeepAliveFrameFlyweight.respondFlag(sent), + is(false)); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleResponseFrameNoError() throws Exception { + final int streamId = 4; + rule.connection.clearSendReceiveBuffers(); + + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + Collection> sendSubscribers = rule.connection.getSendSubscribers(); + assertThat("Request not sent.", sendSubscribers, hasSize(1)); + assertThat("Unexpected error.", rule.errors, is(empty())); + Subscriber sendSub = sendSubscribers.iterator().next(); + assertThat( + "Unexpected frame sent.", + frameType(rule.connection.awaitSend()), + anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandlerEmitsError() throws Exception { + final int streamId = 4; + rule.sendRequest(streamId, FrameType.REQUEST_STREAM); + assertThat("Unexpected error.", rule.errors, is(empty())); + assertThat( + "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); + } + + @Test + @Timeout(20_000) + public void testCancel() { + ByteBufAllocator allocator = rule.alloc(); + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return Mono.never().doOnCancel(() -> cancelled.set(true)); + } + }); + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + assertThat("Unexpected error.", rule.errors, is(empty())); + assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + + rule.connection.addToReceivedBuffer(CancelFrameFlyweight.encode(allocator, streamId)); + + assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + final AbstractRSocket acceptingSocket = + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload p) { + p.release(); + return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestStream(Payload p) { + p.release(); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .doOnNext(Payload::release) + .subscribe( + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + subscription.request(1); + } + }); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + }; + rule.setAcceptingSocket(acceptingSocket); + + final Runnable[] runnables = { + () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + }; + + for (Runnable runnable : runnables) { + rule.connection.clearSendReceiveBuffers(); + runnable.run(); + Assertions.assertThat(rule.errors) + .first() + .isInstanceOf(IllegalArgumentException.class) + .hasToString("java.lang.IllegalArgumentException: " + INVALID_PAYLOAD_ERROR_MESSAGE); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR) + .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)) + .matches(ReferenceCounted::release); + + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + } + + rule.assertHasNoLeaks(); + } + + @Test + public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return Flux.never(); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + }, + assertSubscriber::cancel); + + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + ByteBuf requestNFrame = RequestNFrameFlyweight.encode(allocator, 1, Integer.MAX_VALUE); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void + checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + + return Flux.create( + sink -> { + sinks[0] = sink; + }, + FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); + + ByteBuf requestNFrame = RequestNFrameFlyweight.encode(allocator, 1, Integer.MAX_VALUE); + + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.error(new RuntimeException()); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + sources[0] = new Operators.MonoSubscriber<>(actual); + actual.onSubscribe(sources[0]); + } + }; + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_RESPONSE); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sources[0].complete(ByteBufPayload.create("d1", "m1")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void simpleDiscardRequestStreamTest() { + ByteBufAllocator allocator = rule.alloc(); + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + FluxSink sink = sinks[0]; + + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleDiscardRequestChannelTest() { + ByteBufAllocator allocator = rule.alloc(); + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return (Flux) payloads; + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("de3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(framesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Mono fireAndForget(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return testPublisher.flux(); + } + }, + 1); + + rule.sendRequest(1, frameType, ByteBufPayload.create("d")); + + // if responses number is bigger than 1 we have to send one extra requestN + if (responsesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameFlyweight.encode(allocator, 1, responsesCnt - 1)); + } + + // respond with specific number of elements + for (int i = 0; i < responsesCnt; i++) { + testPublisher.next(ByteBufPayload.create("rd" + i)); + } + + // Listen to incoming frames. Valid for RequestChannel case only + if (framesCnt > 1) { + for (int i = 1; i < responsesCnt; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameFlyweight.encode( + allocator, + 1, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("d" + (i + 1)).getBytes()))); + } + } + + if (responsesCnt > 0) { + Assertions.assertThat( + rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(bb -> !FrameHeaderFlyweight.hasMetadata(bb)); + } + + if (framesCnt > 1) { + Assertions.assertThat( + rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe single RequestN(%s) frame", + frameType, framesCnt - 1) + .hasSize(1) + .first() + .matches(bb -> RequestNFrameFlyweight.requestN(bb) == (framesCnt - 1)); + } + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + Assertions.assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) + .hasSize(framesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + public static class ServerSocketRule extends AbstractSocketRule { + + private RSocket acceptingSocket; + private volatile int prefetch; + + @Override + protected void init() { + acceptingSocket = + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + }; + super.init(); + } + + public void setAcceptingSocket(RSocket acceptingSocket) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + errors = new ConcurrentLinkedQueue<>(); + this.prefetch = Integer.MAX_VALUE; + super.init(); + } + + public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + errors = new ConcurrentLinkedQueue<>(); + this.prefetch = prefetch; + super.init(); + } + + @Override + protected RSocketResponder newRSocket() { + return new RSocketResponder( + connection, + acceptingSocket, + PayloadDecoder.ZERO_COPY, + throwable -> errors.add(throwable), + ResponderLeaseHandler.None, + 0); + } + + private void sendRequest(int streamId, FrameType frameType) { + sendRequest(streamId, frameType, EmptyPayload.INSTANCE); + } + + private void sendRequest(int streamId, FrameType frameType, Payload payload) { + ByteBuf request; + + switch (frameType) { + case REQUEST_CHANNEL: + request = + RequestChannelFrameFlyweight.encodeReleasingPayload( + allocator, streamId, false, prefetch, payload); + break; + case REQUEST_STREAM: + request = + RequestStreamFrameFlyweight.encodeReleasingPayload( + allocator, streamId, prefetch, payload); + break; + case REQUEST_RESPONSE: + request = + RequestResponseFrameFlyweight.encodeReleasingPayload(allocator, streamId, payload); + break; + case REQUEST_FNF: + request = + RequestFireAndForgetFrameFlyweight.encodeReleasingPayload( + allocator, streamId, payload); + break; + default: + throw new IllegalArgumentException("unsupported type: " + frameType); + } + + connection.addToReceivedBuffer(request); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java new file mode 100644 index 000000000..9d105a8c9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java @@ -0,0 +1,43 @@ +package io.rsocket.core; + +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestServerTransport; +import org.assertj.core.api.Assertions; +import org.junit.Test; + +public class RSocketServerFragmentationTest { + + @Test + public void serverErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketServer.create().fragment(2)) + .withMessage("smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void serverSucceedsWithEnabledFragmentationOnSufficientMtu() { + RSocketServer.create().fragment(100).bind(new TestServerTransport()).block(); + } + + @Test + public void serverSucceedsWithDisabledFragmentation() { + RSocketServer.create().bind(new TestServerTransport()).block(); + } + + @Test + public void clientErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketConnector.create().fragment(2)) + .withMessage("smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void clientSucceedsWithEnabledFragmentationOnSufficientMtu() { + RSocketConnector.create().fragment(100).connect(TestClientTransport::new).block(); + } + + @Test + public void clientSucceedsWithDisabledFragmentation() { + RSocketConnector.connectWith(new TestClientTransport()).block(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java new file mode 100644 index 000000000..4a2c43ef8 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -0,0 +1,527 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.AbstractRSocket; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.test.util.LocalDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import org.assertj.core.api.Assertions; +import org.hamcrest.MatcherAssert; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExternalResource; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; +import org.mockito.ArgumentCaptor; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import reactor.core.publisher.DirectProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; + +public class RSocketTest { + + @Rule public final SocketRule rule = new SocketRule(); + + public static void assertError(String s, String mode, ArrayList errors) { + for (Throwable t : errors) { + if (t.toString().equals(s)) { + return; + } + } + + Assert.fail("Expected " + mode + " connection error: " + s + " other errors " + errors.size()); + } + + @Test(timeout = 2_000) + public void testRequestReplyNoError() { + StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) + .expectNextCount(1) + .expectComplete() + .verify(); + } + + @Test(timeout = 2000) + public void testHandlerEmitsError() { + rule.setRequestAcceptor( + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error(new NullPointerException("Deliberate exception.")); + } + }); + Subscriber subscriber = TestSubscriber.create(); + rule.crs.requestResponse(EmptyPayload.INSTANCE).subscribe(subscriber); + verify(subscriber).onError(any(ApplicationErrorException.class)); + + // Client sees error through normal API + rule.assertNoClientErrors(); + + rule.assertServerError("java.lang.NullPointerException: Deliberate exception."); + } + + @Test(timeout = 2000) + public void testHandlerEmitsCustomError() { + rule.setRequestAcceptor( + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error( + new CustomRSocketException(0x00000501, "Deliberate Custom exception.")); + } + }); + Subscriber subscriber = TestSubscriber.create(); + rule.crs.requestResponse(EmptyPayload.INSTANCE).subscribe(subscriber); + ArgumentCaptor customRSocketExceptionArgumentCaptor = + ArgumentCaptor.forClass(CustomRSocketException.class); + verify(subscriber).onError(customRSocketExceptionArgumentCaptor.capture()); + + Assert.assertEquals( + "Deliberate Custom exception.", + customRSocketExceptionArgumentCaptor.getValue().getMessage()); + Assert.assertEquals(0x00000501, customRSocketExceptionArgumentCaptor.getValue().errorCode()); + + // Client sees error through normal API + rule.assertNoClientErrors(); + + rule.assertServerError("CustomRSocketException (0x501): Deliberate Custom exception."); + } + + @Test(timeout = 2000) + public void testRequestPropagatesCorrectlyForRequestChannel() { + rule.setRequestAcceptor( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + // specifically limits request to 3 in order to prevent 256 request from limitRate + // hidden on the responder side + .limitRequest(3); + } + }); + + Flux.range(0, 3) + .map(i -> DefaultPayload.create("" + i)) + .as(rule.crs::requestChannel) + .as(publisher -> StepVerifier.create(publisher, 3)) + .expectSubscription() + .expectNextCount(3) + .expectComplete() + .verify(Duration.ofMillis(5000)); + + rule.assertNoClientErrors(); + rule.assertNoServerErrors(); + } + + @Test(timeout = 2000) + public void testStream() throws Exception { + Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test(timeout = 2000) + public void testChannel() throws Exception { + Flux requests = + Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test(timeout = 2000) + public void testErrorPropagatesCorrectly() { + AtomicReference error = new AtomicReference<>(); + rule.setRequestAcceptor( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).doOnError(error::set); + } + }); + Flux requests = Flux.error(new RuntimeException("test")); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectErrorMessage("test").verify(); + Assertions.assertThat(error.get()).isNull(); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion1() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion2() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_CancellationFromResponderShouldLeaveStreamInHalfClosedStateWithNextCompletionPossibleFromRequester() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void + requestChannelCase_CompletionFromRequesterShouldLeaveStreamInHalfClosedStateWithNextCancellationPossibleFromResponder() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_ensureThatRequesterSubscriberCancellationTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + cancelFromRequesterSubscriber( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + void initRequestChannelCase( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + rule.setRequestAcceptor( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(responderSubscriber); + return responderPublisher.flux(); + } + }); + + rule.crs.requestChannel(requesterPublisher).subscribe(requesterSubscriber); + + requesterPublisher.assertWasSubscribed(); + requesterSubscriber.assertSubscribed(); + + responderSubscriber.assertNotSubscribed(); + responderPublisher.assertWasNotSubscribed(); + + // firstRequest + requesterSubscriber.request(1); + requesterPublisher.assertMaxRequested(1); + requesterPublisher.next(DefaultPayload.create("initialData", "initialMetadata")); + + responderSubscriber.assertSubscribed(); + responderPublisher.assertWasSubscribed(); + } + + void nextFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that outerUpstream and innerSubscriber is not terminated so the requestChannel + requesterPublisher.assertSubscribers(1); + responderSubscriber.assertNotTerminated(); + + responderSubscriber.request(6); + requesterPublisher.next( + DefaultPayload.create("d1", "m1"), + DefaultPayload.create("d2"), + DefaultPayload.create("d3", "m3"), + DefaultPayload.create("d4"), + DefaultPayload.create("d5", "m5")); + + List innerPayloads = responderSubscriber.awaitAndAssertNextValueCount(6).values(); + Assertions.assertThat(innerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("initialData", "d1", "d2", "d3", "d4", "d5"); + Assertions.assertThat(innerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, true, false, true, false, true); + Assertions.assertThat(innerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("initialMetadata", "m1", "", "m3", "", "m5"); + } + + void completeFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + requesterPublisher.complete(); + responderSubscriber.assertTerminated(); + requesterPublisher.assertNoSubscribers(); + } + + void cancelFromResponderSubscriber( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + responderSubscriber.cancel(); + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + void nextFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that downstream is not terminated so the requestChannel state is half-closed + responderPublisher.assertSubscribers(1); + requesterSubscriber.assertNotTerminated(); + + // ensures responderPublisher can send messages and outerSubscriber can receive them + requesterSubscriber.request(5); + responderPublisher.next( + DefaultPayload.create("rd1", "rm1"), + DefaultPayload.create("rd2"), + DefaultPayload.create("rd3", "rm3"), + DefaultPayload.create("rd4"), + DefaultPayload.create("rd5", "rm5")); + + List outerPayloads = requesterSubscriber.awaitAndAssertNextValueCount(5).values(); + Assertions.assertThat(outerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("rd1", "rd2", "rd3", "rd4", "rd5"); + Assertions.assertThat(outerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, false, true, false, true); + Assertions.assertThat(outerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("rm1", "", "rm3", "", "rm5"); + } + + void completeFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that after sending complete inner upstream is closed + responderPublisher.complete(); + requesterSubscriber.assertTerminated(); + responderPublisher.assertNoSubscribers(); + } + + void cancelFromRequesterSubscriber( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + requesterSubscriber.cancel(); + // error should be propagated + responderSubscriber.assertTerminated(); + responderPublisher.assertWasCancelled(); + responderPublisher.assertNoSubscribers(); + // ensures that cancellation is propagated to the actual upstream + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + public static class SocketRule extends ExternalResource { + + DirectProcessor serverProcessor; + DirectProcessor clientProcessor; + private RSocketRequester crs; + + @SuppressWarnings("unused") + private RSocketResponder srs; + + private RSocket requestAcceptor; + private ArrayList clientErrors = new ArrayList<>(); + private ArrayList serverErrors = new ArrayList<>(); + + private LeaksTrackingByteBufAllocator allocator; + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + init(); + base.evaluate(); + } + }; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + protected void init() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + serverProcessor = DirectProcessor.create(); + clientProcessor = DirectProcessor.create(); + + LocalDuplexConnection serverConnection = + new LocalDuplexConnection("server", allocator, clientProcessor, serverProcessor); + LocalDuplexConnection clientConnection = + new LocalDuplexConnection("client", allocator, serverProcessor, clientProcessor); + + requestAcceptor = + null != requestAcceptor + ? requestAcceptor + : new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 10) + .map( + i -> DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")) + .subscribe(); + + return Flux.range(1, 10) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + }; + + srs = + new RSocketResponder( + serverConnection, + requestAcceptor, + PayloadDecoder.DEFAULT, + throwable -> serverErrors.add(throwable), + ResponderLeaseHandler.None, + 0); + + crs = + new RSocketRequester( + clientConnection, + PayloadDecoder.DEFAULT, + throwable -> clientErrors.add(throwable), + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None); + } + + public void setRequestAcceptor(RSocket requestAcceptor) { + this.requestAcceptor = requestAcceptor; + init(); + } + + public void assertNoErrors() { + assertNoClientErrors(); + assertNoServerErrors(); + } + + public void assertNoClientErrors() { + MatcherAssert.assertThat( + "Unexpected error on the client connection.", clientErrors, is(empty())); + } + + public void assertNoServerErrors() { + MatcherAssert.assertThat( + "Unexpected error on the server connection.", serverErrors, is(empty())); + } + + public void assertClientError(String s) { + assertError(s, "client", this.clientErrors); + } + + public void assertServerError(String s) { + assertError(s, "server", this.serverErrors); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java new file mode 100644 index 000000000..968a1a793 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java @@ -0,0 +1,868 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.assertj.core.api.Assertions; +import org.junit.Test; +import org.mockito.Mockito; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ReconnectMonoTests { + + private Queue retries = new ConcurrentLinkedQueue<>(); + private Queue> received = new ConcurrentLinkedQueue<>(); + private Queue expired = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldExpireValueOnRacingDisposeAndNext() { + Hooks.onErrorDropped(t -> {}); + Hooks.onNextDropped(System.out::println); + for (int i = 0; i < 100000; i++) { + final int index = i; + final CoreSubscriber[] monoSubscribers = new CoreSubscriber[1]; + Subscription mockSubscription = Mockito.mock(Subscription.class); + final Mono stringMono = + new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + actual.onSubscribe(mockSubscription); + monoSubscribers[0] = actual; + } + }; + + final ReconnectMono reconnectMono = + stringMono + .doOnDiscard(Object.class, System.out::println) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + RaceTestUtils.race(() -> monoSubscribers[0].onNext("value" + index), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + Mockito.verify(mockSubscription).cancel(); + + if (processor.isError()) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + + Assertions.assertThat(expired).containsOnly("value" + i); + } else { + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final MonoProcessor racerProcessor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(racerProcessor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = MonoProcessor.create(); + final MonoProcessor racerProcessor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + RaceTestUtils.race( + () -> reconnectMono.subscribe(processor), () -> reconnectMono.subscribe(racerProcessor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + Assertions.assertThat(racerProcessor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + String[] values = new String[1]; + + RaceTestUtils.race( + () -> values[0] = reconnectMono.block(timeout), () -> reconnectMono.subscribe(processor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(values).containsExactly("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + String[] values1 = new String[1]; + String[] values2 = new String[1]; + + RaceTestUtils.race( + () -> values1[0] = reconnectMono.block(timeout), + () -> values2[0] = reconnectMono.block(timeout)); + + Assertions.assertThat(values2).containsExactly("value" + i); + Assertions.assertThat(values1).containsExactly("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add( + new ReconnectMono.ReconnectInner<>(MonoProcessor.create(), reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Throwable error = processor.getError(); + + if (error instanceof CancellationException) { + Assertions.assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(error) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Unexpected Completion of the Upstream"); + } + + Assertions.assertThat(expired).isEmpty(); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndError() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + if (processor.getError() instanceof CancellationException) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(processor.getError()) + .isInstanceOf(RuntimeException.class) + .hasMessage("test"); + } + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono() + .retryWhen(Retry.max(1).filter(t -> t instanceof Exception)) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + + if (processor.getError() instanceof CancellationException) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(processor.getError()) + .matches(t -> Exceptions.isRetryExhausted(t)) + .hasCause(runtimeException); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldThrowOnBlocking() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on Mono blocking read"); + } + + @Test + public void shouldThrowOnBlockingIfHasAlreadyTerminated() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + publisher.error(new RuntimeException("test")); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(RuntimeException.class) + .hasMessage("test") + .hasSuppressedException(new Exception("ReconnectMono terminated with an error")); + } + + @Test + public void shouldBeScannable() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final Mono parent = publisher.mono(); + final ReconnectMono reconnectMono = + parent.as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final Scannable scannableOfReconnect = Scannable.from(reconnectMono); + + Assertions.assertThat( + (List) + scannableOfReconnect.parents().map(s -> s.getClass()).collect(Collectors.toList())) + .hasSize(1) + .containsExactly(publisher.mono().getClass()); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) + .isEqualTo(false); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + final Scannable scannableOfMonoProcessor = Scannable.from(processor); + + Assertions.assertThat( + (List) + scannableOfMonoProcessor + .parents() + .map(s -> s.getClass()) + .collect(Collectors.toList())) + .hasSize(3) + .containsExactly( + ReconnectMono.ReconnectInner.class, ReconnectMono.class, publisher.mono().getClass()); + + reconnectMono.dispose(); + + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) + .isEqualTo(true); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) + .isInstanceOf(CancellationException.class); + } + + @Test + public void shouldNotExpiredIfNotCompleted() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + reconnectMono.invalidate(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + publisher.assertSubscribers(1); + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isTrue(); + + publisher.assertSubscribers(0); + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldNotEmitUntilCompletion() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isTrue(); + Assertions.assertThat(processor.peek()).isEqualTo("test"); + } + + @Test + public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + processor.cancel(); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.EMPTY_SUBSCRIBED); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isFalse(); + Assertions.assertThat(processor.peek()).isNull(); + } + + @Test + public void shouldExpireValueOnDispose() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono) + .expectSubscription() + .then(() -> publisher.next("value")) + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + + reconnectMono.dispose(); + + Assertions.assertThat(expired).hasSize(1); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(reconnectMono.isDisposed()).isTrue(); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectError(CancellationException.class) + .verify(Duration.ofSeconds(timeout)); + } + + @Test + public void shouldNotifyAllTheSubscribers() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor sub1 = MonoProcessor.create(); + final MonoProcessor sub2 = MonoProcessor.create(); + final MonoProcessor sub3 = MonoProcessor.create(); + final MonoProcessor sub4 = MonoProcessor.create(); + + reconnectMono.subscribe(sub1); + reconnectMono.subscribe(sub2); + reconnectMono.subscribe(sub3); + reconnectMono.subscribe(sub4); + + Assertions.assertThat(reconnectMono.subscribers).hasSize(4); + + final ArrayList> processors = new ArrayList<>(200); + + for (int i = 0; i < 100; i++) { + final MonoProcessor subA = MonoProcessor.create(); + final MonoProcessor subB = MonoProcessor.create(); + processors.add(subA); + processors.add(subB); + RaceTestUtils.race(() -> reconnectMono.subscribe(subA), () -> reconnectMono.subscribe(subB)); + } + + Assertions.assertThat(reconnectMono.subscribers).hasSize(204); + + sub1.dispose(); + + Assertions.assertThat(reconnectMono.subscribers).hasSize(203); + + publisher.next("value"); + + Assertions.assertThatThrownBy(sub1::peek).isInstanceOf(CancellationException.class); + Assertions.assertThat(sub2.peek()).isEqualTo("value"); + Assertions.assertThat(sub3.peek()).isEqualTo("value"); + Assertions.assertThat(sub4.peek()).isEqualTo("value"); + + for (MonoProcessor sub : processors) { + Assertions.assertThat(sub.peek()).isEqualTo("value"); + Assertions.assertThat(sub.isTerminated()).isTrue(); + } + + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldExpireValueExactlyOnce() { + for (int i = 0; i < 1000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value"); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::invalidate); + + Assertions.assertThat(expired).hasSize(1).containsOnly("value"); + Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).hasSize(1).containsOnly("value"); + Assertions.assertThat(received) + .hasSize(2) + .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value", reconnectMono)); + + Assertions.assertThat(cold.subscribeCount()).isEqualTo(2); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldTimeoutRetryWithVirtualTime() { + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + // then + StepVerifier.withVirtualTime( + () -> + Mono.error(new RuntimeException("Something went wrong")) + .retryWhen( + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(minBackoff)) + .doAfterRetry(onRetry()) + .maxBackoff(Duration.ofSeconds(maxBackoff))) + .timeout(Duration.ofSeconds(timeout)) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())) + .subscribeOn(Schedulers.elastic())) + .expectSubscription() + .thenAwait(Duration.ofSeconds(timeout)) + .expectError(TimeoutException.class) + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryNoBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.max(2).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.create(mono).verifyErrorMatches(Exceptions::isRetryExhausted); + assertRetries(IOException.class, IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryFixedBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.fixedDelay(1, Duration.ofMillis(500)).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .expectNoEvent(Duration.ofMillis(300)) + .thenAwait(Duration.ofMillis(300)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryExponentialBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .jitter(0.0d) + .doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .thenAwait(Duration.ofMillis(100)) + .thenAwait(Duration.ofMillis(200)) + .thenAwait(Duration.ofMillis(400)) + .thenAwait(Duration.ofMillis(500)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class, IOException.class, IOException.class, IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + Consumer onRetry() { + return context -> retries.add(context); + } + + BiConsumer onValue() { + return (v, __) -> received.add(Tuples.of(v, __)); + } + + Consumer onExpire() { + return (v) -> expired.add(v); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertEquals(exceptions.length, retries.size()); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertEquals(index, retryContext.totalRetries()); + assertEquals(exceptions[index], retryContext.failure().getClass()); + index++; + } + } + + static boolean isRetryExhausted(Throwable e, Class cause) { + return Exceptions.isRetryExhausted(e) && cause.isInstance(e.getCause()); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java similarity index 86% rename from rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java rename to rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index e6972eec0..db72c7775 100644 --- a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -1,10 +1,12 @@ -package io.rsocket; +package io.rsocket.core; import static io.rsocket.transport.ServerTransport.ConnectionAcceptor; import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.rsocket.*; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.ErrorFrameFlyweight; @@ -31,7 +33,7 @@ void responderRejectSetup() { String errorMsg = "error"; RejectingAcceptor acceptor = new RejectingAcceptor(errorMsg); - RSocketFactory.receive().acceptor(acceptor).transport(transport).start().block(); + RSocketServer.create().acceptor(acceptor).bind(transport).block(); transport.connect(); @@ -46,17 +48,19 @@ void responderRejectSetup() { @Test void requesterStreamsTerminatedOnZeroErrorFrame() { - TestDuplexConnection conn = new TestDuplexConnection(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); List errors = new ArrayList<>(); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, conn, DefaultPayload::create, errors::add, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); @@ -82,16 +86,18 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { @Test void requesterNewStreamsTerminatedAfterZeroErrorFrame() { - TestDuplexConnection conn = new TestDuplexConnection(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, conn, DefaultPayload::create, err -> {}, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, RequesterLeaseHandler.None); @@ -129,7 +135,9 @@ public Mono senderRSocket() { private static class SingleConnectionTransport implements ServerTransport { - private final TestDuplexConnection conn = new TestDuplexConnection(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection conn = new TestDuplexConnection(allocator); @Override public Mono start(ConnectionAcceptor acceptor, int mtu) { @@ -147,8 +155,7 @@ public ByteBuf awaitSent() { public void connect() { Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); ByteBuf setup = - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, false, 0, 42, "mdMime", "dMime", payload); + SetupFrameFlyweight.encode(allocator, false, 0, 42, "mdMime", "dMime", payload); conn.addToReceivedBuffer(setup); } diff --git a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java similarity index 99% rename from rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java rename to rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java index 766a6aaf7..00248b6d8 100644 --- a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; diff --git a/rsocket-core/src/test/java/io/rsocket/TestingStuff.java b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java similarity index 97% rename from rsocket-core/src/test/java/io/rsocket/TestingStuff.java rename to rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java index 64c790053..e0ebf5064 100644 --- a/rsocket-core/src/test/java/io/rsocket/TestingStuff.java +++ b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java @@ -1,4 +1,4 @@ -package io.rsocket; +package io.rsocket.core; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; @@ -16,6 +16,6 @@ public void testStuff() { ByteBuf byteBuf = Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(f1)); System.out.println(ByteBufUtil.prettyHexDump(byteBuf)); - ConnectionSetupPayload.create(byteBuf); + new DefaultConnectionSetupPayload(byteBuf); } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java index 8c39e8250..ccf7649d2 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java @@ -17,27 +17,22 @@ package io.rsocket.exceptions; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; interface RSocketExceptionTest { - @DisplayName("constructor throws NullPointerException with null message") + @DisplayName("constructor does not throw NullPointerException with null message") @Test default void constructorWithNullMessage() { - assertThatNullPointerException() - .isThrownBy(() -> getException(null)) - .withMessage("message must not be null"); + assertThat(getException(null)).hasMessage(null); } - @DisplayName("constructor throws NullPointerException with null message and cause") + @DisplayName("constructor does not throw NullPointerException with null message and cause") @Test default void constructorWithNullMessageAndCause() { - assertThatNullPointerException() - .isThrownBy(() -> getException(null, new Exception())) - .withMessage("message must not be null"); + assertThat(getException(null)).hasMessage(null); } @DisplayName("errorCode returns specified value") diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java index 3d96bfd12..9050eaa90 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,18 +22,16 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.*; -import io.rsocket.util.DefaultPayload; -import java.util.Arrays; -import java.util.List; import java.util.concurrent.ThreadLocalRandom; import org.junit.Assert; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -50,20 +48,22 @@ final class FragmentationDuplexConnectionTest { private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + { + Mockito.when(delegate.onClose()).thenReturn(Mono.never()); + } + @SuppressWarnings("unchecked") private final ArgumentCaptor> publishers = ArgumentCaptor.forClass(Publisher.class); - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") @Test void constructorInvalidMaxFragmentSize() { assertThatIllegalArgumentException() - .isThrownBy( - () -> - new FragmentationDuplexConnection( - delegate, allocator, Integer.MIN_VALUE, false, "")) + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, Integer.MIN_VALUE, false, "")) .withMessage("smallest allowed mtu size is 64 bytes, provided: -2147483648"); } @@ -71,236 +71,18 @@ void constructorInvalidMaxFragmentSize() { @Test void constructorMtuLessThanMin() { assertThatIllegalArgumentException() - .isThrownBy(() -> new FragmentationDuplexConnection(delegate, allocator, 2, false, "")) + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, 2, false, "")) .withMessage("smallest allowed mtu size is 64 bytes, provided: 2"); } - @DisplayName("constructor throws NullPointerException with null byteBufAllocator") - @Test - void constructorNullByteBufAllocator() { - assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(delegate, null, 64, false, "")) - .withMessage("byteBufAllocator must not be null"); - } - @DisplayName("constructor throws NullPointerException with null delegate") @Test void constructorNullDelegate() { assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(null, allocator, 64, false, "")) + .isThrownBy(() -> new FragmentationDuplexConnection(null, 64, false, "")) .withMessage("delegate must not be null"); } - @DisplayName("reassembles data") - @Test - void reassembleData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata") - @Test - void reassembleMetadata() { - List byteBufs = - Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - false, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestResponseFrameFlyweight.metadata(byteBuf); - Assert.assertEquals(metadata, m); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata and data") - @Test - void reassembleMetadataAndData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); - Assert.assertEquals(metadata, RequestResponseFrameFlyweight.metadata(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble a non-fragment frame") - @Test - void reassembleNonFragment() { - ByteBuf encode = - RequestResponseFrameFlyweight.encode( - allocator, 1, false, DefaultPayload.create(Unpooled.wrappedBuffer(data))); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals( - Unpooled.wrappedBuffer(data), RequestResponseFrameFlyweight.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble non fragmentable frame") - @Test - void reassembleNonFragmentableFrame() { - ByteBuf encode = CancelFrameFlyweight.encode(allocator, 2); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.CANCEL, FrameHeaderFlyweight.frameType(byteBuf)); - }) - .verifyComplete(); - } - @DisplayName("fragments data") @Test void sendData() { @@ -309,8 +91,9 @@ void sendData() { allocator, 1, false, Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(data)); when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); - new FragmentationDuplexConnection(delegate, allocator, 64, false, "").sendOne(encode.retain()); + new FragmentationDuplexConnection(delegate, 64, false, "").sendOne(encode.retain()); verify(delegate).send(publishers.capture()); diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java index 984207936..a8569ef3b 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java @@ -28,7 +28,8 @@ public class FragmentationIntegrationTest { @Test void fragmentAndReassembleData() { ByteBuf frame = - PayloadFrameFlyweight.encodeNextComplete(allocator, 2, DefaultPayload.create(data)); + PayloadFrameFlyweight.encodeNextCompleteReleasingPayload( + allocator, 2, DefaultPayload.create(data)); System.out.println(FrameUtil.toString(frame)); frame.retain(); diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java index f5a013357..c6b1735e6 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java @@ -20,7 +20,6 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.rsocket.frame.*; -import io.rsocket.util.DefaultPayload; import java.util.concurrent.ThreadLocalRandom; import org.junit.Assert; import org.junit.jupiter.api.DisplayName; @@ -43,14 +42,17 @@ final class FrameFragmenterTest { @Test void testGettingData() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameFlyweight.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf fnf = - RequestFireAndForgetFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestFireAndForgetFrameFlyweight.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf rs = - RequestStreamFrameFlyweight.encode(allocator, 1, true, 1, DefaultPayload.create(data)); + RequestStreamFrameFlyweight.encode( + allocator, 1, true, 1, null, Unpooled.wrappedBuffer(data)); ByteBuf rc = RequestChannelFrameFlyweight.encode( - allocator, 1, true, false, 1, DefaultPayload.create(data)); + allocator, 1, true, false, 1, null, Unpooled.wrappedBuffer(data)); ByteBuf data = FrameFragmenter.getData(rr, FrameType.REQUEST_RESPONSE); Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); @@ -73,16 +75,22 @@ void testGettingData() { void testGettingMetadata() { ByteBuf rr = RequestResponseFrameFlyweight.encode( - allocator, 1, true, DefaultPayload.create(data, metadata)); + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf fnf = RequestFireAndForgetFrameFlyweight.encode( - allocator, 1, true, DefaultPayload.create(data, metadata)); + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf rs = RequestStreamFrameFlyweight.encode( - allocator, 1, true, 1, DefaultPayload.create(data, metadata)); + allocator, 1, true, 1, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf rc = RequestChannelFrameFlyweight.encode( - allocator, 1, true, false, 1, DefaultPayload.create(data, metadata)); + allocator, + 1, + true, + false, + 1, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)); ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); @@ -104,7 +112,8 @@ void testGettingMetadata() { @Test void returnEmptBufferWhenNoMetadataPresent() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameFlyweight.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); @@ -115,7 +124,8 @@ void returnEmptBufferWhenNoMetadataPresent() { @Test void encodeFirstFrameWithData() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameFlyweight.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -144,7 +154,7 @@ void encodeFirstFrameWithData() { void encodeFirstWithDataChannel() { ByteBuf rc = RequestChannelFrameFlyweight.encode( - allocator, 1, true, false, 10, DefaultPayload.create(data)); + allocator, 1, true, false, 10, null, Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -173,7 +183,8 @@ void encodeFirstWithDataChannel() { @Test void encodeFirstWithDataStream() { ByteBuf rc = - RequestStreamFrameFlyweight.encode(allocator, 1, true, 50, DefaultPayload.create(data)); + RequestStreamFrameFlyweight.encode( + allocator, 1, true, 50, null, Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -203,10 +214,7 @@ void encodeFirstWithDataStream() { void encodeFirstFrameWithMetadata() { ByteBuf rr = RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))); + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -234,7 +242,7 @@ void encodeFirstFrameWithMetadata() { void encodeFirstWithDataAndMetadataStream() { ByteBuf rc = RequestStreamFrameFlyweight.encode( - allocator, 1, true, 50, DefaultPayload.create(data, metadata)); + allocator, 1, true, 50, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -266,7 +274,8 @@ void encodeFirstWithDataAndMetadataStream() { @Test void fragmentData() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameFlyweight.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)); Publisher fragments = FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); @@ -293,11 +302,7 @@ void fragmentData() { void fragmentMetadata() { ByteBuf rr = RequestStreamFrameFlyweight.encode( - allocator, - 1, - true, - 10, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))); + allocator, 1, true, 10, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); Publisher fragments = FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_STREAM, false); @@ -324,7 +329,7 @@ void fragmentMetadata() { void fragmentDataAndMetadata() { ByteBuf rr = RequestResponseFrameFlyweight.encode( - allocator, 1, true, DefaultPayload.create(data, metadata)); + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); Publisher fragments = FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java index 6e0d0dc1b..13632165b 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java @@ -22,7 +22,6 @@ import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCountUtil; import io.rsocket.frame.*; -import io.rsocket.util.DefaultPayload; import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; @@ -48,15 +47,16 @@ final class FrameReassemblerTest { void reassembleData() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)), + RequestResponseFrameFlyweight.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)), PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -88,7 +88,8 @@ void reassembleData() { void passthrough() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode(allocator, 1, false, DefaultPayload.create(data))); + RequestResponseFrameFlyweight.encode( + allocator, 1, false, null, Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -115,38 +116,39 @@ void reassembleMetadata() { List byteBufs = Arrays.asList( RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, false, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -184,35 +186,40 @@ void reassembleMetadataChannel() { true, false, 100, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, false, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -249,39 +256,39 @@ void reassembleMetadataStream() { List byteBufs = Arrays.asList( RequestStreamFrameFlyweight.encode( - allocator, - 1, - true, - 250, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + allocator, 1, true, 250, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, false, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -319,34 +326,33 @@ void reassembleMetadataAndData() { List byteBufs = Arrays.asList( RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)), PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -387,32 +393,31 @@ public void cancelBeforeAssembling() { List byteBufs = Arrays.asList( RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); @@ -436,32 +441,31 @@ public void dispose() { List byteBufs = Arrays.asList( RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), PayloadFrameFlyweight.encode( allocator, 1, true, false, true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java new file mode 100644 index 000000000..013e2ebc2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java @@ -0,0 +1,273 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.fragmentation; + +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.CancelFrameFlyweight; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +final class ReassembleDuplexConnectionTest { + private static byte[] data = new byte[1024]; + private static byte[] metadata = new byte[1024]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } + + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @DisplayName("reassembles data") + @Test + void reassembleData() { + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameFlyweight.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("reassembles metadata") + @Test + void reassembleMetadata() { + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameFlyweight.encode( + allocator, + 1, + false, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestResponseFrameFlyweight.metadata(byteBuf); + Assert.assertEquals(metadata, m); + }) + .verifyComplete(); + } + + @DisplayName("reassembles metadata and data") + @Test + void reassembleMetadataAndData() { + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)), + PayloadFrameFlyweight.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + Assert.assertEquals(metadata, RequestResponseFrameFlyweight.metadata(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("does not reassemble a non-fragment frame") + @Test + void reassembleNonFragment() { + ByteBuf encode = + RequestResponseFrameFlyweight.encode( + allocator, 1, false, null, Unpooled.wrappedBuffer(data)); + + when(delegate.receive()).thenReturn(Flux.just(encode)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals( + Unpooled.wrappedBuffer(data), RequestResponseFrameFlyweight.data(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("does not reassemble non fragmentable frame") + @Test + void reassembleNonFragmentableFrame() { + ByteBuf encode = CancelFrameFlyweight.encode(allocator, 2); + + when(delegate.receive()).thenReturn(Flux.just(encode)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.CANCEL, FrameHeaderFlyweight.frameType(byteBuf)); + }) + .verifyComplete(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index b22a95c0b..63300c718 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -17,6 +17,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; import org.assertj.core.presentation.StandardRepresentation; public final class ByteBufRepresentation extends StandardRepresentation { @@ -24,7 +25,17 @@ public final class ByteBufRepresentation extends StandardRepresentation { @Override protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { - return ByteBufUtil.prettyHexDump((ByteBuf) object); + try { + String normalBufferString = object.toString(); + String prettyHexDump = ByteBufUtil.prettyHexDump((ByteBuf) object); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } catch (IllegalReferenceCountException e) { + // noops + } } return super.fallbackToStringOf(object); diff --git a/rsocket-core/src/test/java/io/rsocket/frame/DataAndMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/DataAndMetadataFlyweightTest.java deleted file mode 100644 index 6f9113d73..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/DataAndMetadataFlyweightTest.java +++ /dev/null @@ -1,51 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.*; -import org.junit.jupiter.api.Test; - -class DataAndMetadataFlyweightTest { - @Test - void testEncodeData() { - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.PAYLOAD, 0); - ByteBuf data = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm data_"); - ByteBuf frame = DataAndMetadataFlyweight.encodeOnlyData(ByteBufAllocator.DEFAULT, header, data); - ByteBuf d = DataAndMetadataFlyweight.data(frame, false); - String s = ByteBufUtil.prettyHexDump(d); - System.out.println(s); - } - - @Test - void testEncodeMetadata() { - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.PAYLOAD, 0); - ByteBuf data = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm metadata_"); - ByteBuf frame = - DataAndMetadataFlyweight.encodeOnlyMetadata(ByteBufAllocator.DEFAULT, header, data); - ByteBuf d = DataAndMetadataFlyweight.data(frame, false); - String s = ByteBufUtil.prettyHexDump(d); - System.out.println(s); - } - - @Test - void testEncodeDataAndMetadata() { - ByteBuf header = - FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.REQUEST_RESPONSE, 0); - ByteBuf data = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm data_"); - ByteBuf metadata = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm metadata_"); - ByteBuf frame = - DataAndMetadataFlyweight.encode(ByteBufAllocator.DEFAULT, header, metadata, data); - ByteBuf m = DataAndMetadataFlyweight.metadata(frame, true); - String s = ByteBufUtil.prettyHexDump(m); - System.out.println(s); - FrameType frameType = FrameHeaderFlyweight.frameType(frame); - System.out.println(frameType); - - for (int i = 0; i < 10_000_000; i++) { - ByteBuf d1 = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm data_"); - ByteBuf m1 = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm metadata_"); - ByteBuf h1 = - FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.REQUEST_RESPONSE, 0); - ByteBuf f1 = DataAndMetadataFlyweight.encode(ByteBufAllocator.DEFAULT, h1, m1, d1); - f1.release(); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java index e337d4332..eea72c03e 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java @@ -35,7 +35,7 @@ void extensionData() { Assertions.assertFalse(FrameHeaderFlyweight.hasMetadata(extension)); Assertions.assertEquals(extendedType, ExtensionFrameFlyweight.extendedType(extension)); - Assertions.assertEquals(0, ExtensionFrameFlyweight.metadata(extension).readableBytes()); + Assertions.assertNull(ExtensionFrameFlyweight.metadata(extension)); Assertions.assertEquals(data, ExtensionFrameFlyweight.data(extension)); extension.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java index 9ef89326a..439d23c15 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java @@ -15,7 +15,8 @@ public class PayloadFlyweightTest { void nextCompleteDataMetadata() { Payload payload = DefaultPayload.create("d", "md"); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); + PayloadFrameFlyweight.encodeNextCompleteReleasingPayload( + ByteBufAllocator.DEFAULT, 1, payload); String data = PayloadFrameFlyweight.data(nextComplete).toString(StandardCharsets.UTF_8); String metadata = PayloadFrameFlyweight.metadata(nextComplete).toString(StandardCharsets.UTF_8); Assertions.assertEquals("d", data); @@ -27,11 +28,12 @@ void nextCompleteDataMetadata() { void nextCompleteData() { Payload payload = DefaultPayload.create("d"); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); + PayloadFrameFlyweight.encodeNextCompleteReleasingPayload( + ByteBufAllocator.DEFAULT, 1, payload); String data = PayloadFrameFlyweight.data(nextComplete).toString(StandardCharsets.UTF_8); ByteBuf metadata = PayloadFrameFlyweight.metadata(nextComplete); Assertions.assertEquals("d", data); - Assertions.assertTrue(metadata.readableBytes() == 0); + Assertions.assertNull(metadata); nextComplete.release(); } @@ -42,7 +44,8 @@ void nextCompleteMetaData() { Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer("md".getBytes(StandardCharsets.UTF_8))); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); + PayloadFrameFlyweight.encodeNextCompleteReleasingPayload( + ByteBufAllocator.DEFAULT, 1, payload); ByteBuf data = PayloadFrameFlyweight.data(nextComplete); String metadata = PayloadFrameFlyweight.metadata(nextComplete).toString(StandardCharsets.UTF_8); Assertions.assertTrue(data.readableBytes() == 0); @@ -53,7 +56,8 @@ void nextCompleteMetaData() { @Test void nextDataMetadata() { Payload payload = DefaultPayload.create("d", "md"); - ByteBuf next = PayloadFrameFlyweight.encodeNext(ByteBufAllocator.DEFAULT, 1, payload); + ByteBuf next = + PayloadFrameFlyweight.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); String data = PayloadFrameFlyweight.data(next).toString(StandardCharsets.UTF_8); String metadata = PayloadFrameFlyweight.metadata(next).toString(StandardCharsets.UTF_8); Assertions.assertEquals("d", data); @@ -64,11 +68,24 @@ void nextDataMetadata() { @Test void nextData() { Payload payload = DefaultPayload.create("d"); - ByteBuf next = PayloadFrameFlyweight.encodeNext(ByteBufAllocator.DEFAULT, 1, payload); + ByteBuf next = + PayloadFrameFlyweight.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); String data = PayloadFrameFlyweight.data(next).toString(StandardCharsets.UTF_8); ByteBuf metadata = PayloadFrameFlyweight.metadata(next); Assertions.assertEquals("d", data); - Assertions.assertTrue(metadata.readableBytes() == 0); + Assertions.assertNull(metadata); + next.release(); + } + + @Test + void nextDataEmptyMetadata() { + Payload payload = DefaultPayload.create("d".getBytes(), new byte[0]); + ByteBuf next = + PayloadFrameFlyweight.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameFlyweight.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameFlyweight.metadata(next); + Assertions.assertEquals("d", data); + Assertions.assertEquals(metadata.readableBytes(), 0); next.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java index 9acec2c81..c19d4e1f4 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java @@ -22,8 +22,15 @@ void testEncoding() { Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("000010000000011900000000010000026d6464", ByteBufUtil.hexDump(frame)); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length + // | | | | ⌌Encoded Metadata + // | | | | | ⌌Encoded Data + // __|________|_________|______|____|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "000010000000011900000000010000026d6464"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); frame.release(); } @@ -39,8 +46,14 @@ void testEncodingWithEmptyMetadata() { Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("00000e0000000119000000000100000064", ByteBufUtil.hexDump(frame)); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length (0) + // | | | | ⌌Encoded Data + // __|________|_________|_______|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "00000e0000000119000000000100000064"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); frame.release(); } @@ -57,7 +70,13 @@ void testEncodingWithNullMetadata() { frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - assertEquals("00000b0000000118000000000164", ByteBufUtil.hexDump(frame)); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Data + // __|________|_________|_____| + // ↓<-> ↓↓ <-> ↓↓ <-> ↓↓↓ + String expected = "00000b0000000118000000000164"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); frame.release(); } @@ -96,7 +115,7 @@ void requestResponseData() { assertFalse(FrameHeaderFlyweight.hasMetadata(request)); assertEquals("d", data); - assertTrue(metadata.readableBytes() == 0); + assertNull(metadata); request.release(); } @@ -131,13 +150,13 @@ void requestStreamDataMetadata() { Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); + long actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); String data = RequestStreamFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); String metadata = RequestStreamFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(Integer.MAX_VALUE, actualRequest); + assertEquals(Long.MAX_VALUE, actualRequest); assertEquals("md", metadata); assertEquals("d", data); request.release(); @@ -154,13 +173,13 @@ void requestStreamData() { null, Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); + long actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); String data = RequestStreamFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); ByteBuf metadata = RequestStreamFrameFlyweight.metadata(request); assertFalse(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(42, actualRequest); - assertTrue(metadata.readableBytes() == 0); + assertEquals(42L, actualRequest); + assertNull(metadata); assertEquals("d", data); request.release(); } @@ -176,13 +195,13 @@ void requestStreamMetadata() { Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), Unpooled.EMPTY_BUFFER); - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); + long actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); ByteBuf data = RequestStreamFrameFlyweight.data(request); String metadata = RequestStreamFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(42, actualRequest); + assertEquals(42L, actualRequest); assertTrue(data.readableBytes() == 0); assertEquals("md", metadata); request.release(); @@ -223,7 +242,7 @@ void requestFnfData() { assertFalse(FrameHeaderFlyweight.hasMetadata(request)); assertEquals("d", data); - assertTrue(metadata.readableBytes() == 0); + assertNull(metadata); request.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java index 8f56608d8..efa962c48 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java @@ -21,8 +21,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.*; -import io.rsocket.plugins.PluginRegistry; +import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.DefaultPayload; import java.util.concurrent.atomic.AtomicInteger; @@ -32,14 +33,17 @@ public class ClientServerInputMultiplexerTest { private TestDuplexConnection source; private ClientServerInputMultiplexer clientMultiplexer; - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); private ClientServerInputMultiplexer serverMultiplexer; @Before public void setup() { - source = new TestDuplexConnection(); - clientMultiplexer = new ClientServerInputMultiplexer(source, new PluginRegistry(), true); - serverMultiplexer = new ClientServerInputMultiplexer(source, new PluginRegistry(), false); + source = new TestDuplexConnection(allocator); + clientMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); + serverMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java b/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java deleted file mode 100644 index af4c528e9..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java +++ /dev/null @@ -1,140 +0,0 @@ -package io.rsocket.internal; - -import static org.junit.jupiter.api.Assertions.*; - -import java.time.Duration; -import java.util.concurrent.ThreadLocalRandom; -import java.util.function.Consumer; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; - -class RateLimitableRequestPublisherTest { - - @Test - public void testThatRequest1WillBePropagatedUpstream() { - Flux source = - Flux.just(1) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(1)) - .expectNext(1) - .expectComplete() - .verify(Duration.ofMillis(1000)); - } - - @Test - public void testThatRequest256WillBePropagatedToUpstreamWithLimitedRate() { - Flux source = - Flux.range(0, 256) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(256)) - .expectNextCount(256) - .expectComplete() - .verify(Duration.ofMillis(1000)); - } - - @Test - public void testThatRequest256WillBePropagatedToUpstreamWithLimitedRateInFewSteps() { - Flux source = - Flux.range(0, 256) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(10)) - .expectNextCount(5) - .then(() -> rateLimitableRequestPublisher.request(128)) - .expectNextCount(133) - .expectNoEvent(Duration.ofMillis(10)) - .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) - .expectNextCount(118) - .expectComplete() - .verify(Duration.ofMillis(1000)); - } - - @Test - public void testThatRequestInRandomFashionWillBePropagatedToUpstreamWithLimitedRateInFewSteps() { - Flux source = - Flux.range(0, 10000000) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then( - () -> - Flux.interval(Duration.ofMillis(1000)) - .onBackpressureDrop() - .subscribe( - new Consumer() { - int count = 10000000; - - @Override - public void accept(Long __) { - int random = ThreadLocalRandom.current().nextInt(1, 512); - - long request = Math.min(random, count); - - count -= request; - - rateLimitableRequestPublisher.request(count); - } - })) - .expectNextCount(10000000) - .expectComplete() - .verify(Duration.ofMillis(30000)); - } - - @Test - public void testThatRequestLongMaxValueWillBeDeliveredInSeparateChunks() { - Flux source = - Flux.range(0, 10000000) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) - .expectNextCount(10000000) - .expectComplete() - .verify(Duration.ofMillis(30000)); - } - - @Test - public void testThatRequestLongMaxWithIntegerMaxValuePrefetchWillBeDeliveredAsLongMaxValue() { - Flux source = - Flux.range(0, 10000000) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isEqualTo(Long.MAX_VALUE)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, Integer.MAX_VALUE); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) - .expectNextCount(10000000) - .expectComplete() - .verify(Duration.ofMillis(30000)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java index 0dc7d9090..7bf975543 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -17,6 +17,7 @@ package io.rsocket.internal; import io.rsocket.Payload; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.EmptyPayload; import java.util.concurrent.CountDownLatch; import org.junit.Assert; @@ -82,6 +83,36 @@ public void testOnNextAfterSubscribe_1000() throws Exception { testOnNextAfterSubscribeN(1000); } + @Test + public void testPrioritizedSending() { + UnboundedProcessor processor = new UnboundedProcessor<>(); + + for (int i = 0; i < 1000; i++) { + processor.onNext(EmptyPayload.INSTANCE); + } + + processor.onNextPrioritized(ByteBufPayload.create("test")); + + Payload closestPayload = processor.next().block(); + + Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + } + + @Test + public void testPrioritizedFused() { + UnboundedProcessor processor = new UnboundedProcessor<>(); + + for (int i = 0; i < 1000; i++) { + processor.onNext(EmptyPayload.INSTANCE); + } + + processor.onNextPrioritized(ByteBufPayload.create("test")); + + Payload closestPayload = processor.poll(); + + Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + } + public void testOnNextAfterSubscribeN(int n) throws Exception { CountDownLatch latch = new CountDownLatch(n); UnboundedProcessor processor = new UnboundedProcessor<>(); diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java index d945dd45d..58323c066 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import org.reactivestreams.Publisher; import reactor.core.publisher.DirectProcessor; @@ -25,17 +26,22 @@ import reactor.core.publisher.MonoProcessor; public class LocalDuplexConnection implements DuplexConnection { + private final ByteBufAllocator allocator; private final DirectProcessor send; private final DirectProcessor receive; private final MonoProcessor onClose; private final String name; public LocalDuplexConnection( - String name, DirectProcessor send, DirectProcessor receive) { + String name, + ByteBufAllocator allocator, + DirectProcessor send, + DirectProcessor receive) { this.name = name; + this.allocator = allocator; this.send = send; this.receive = receive; - onClose = MonoProcessor.create(); + this.onClose = MonoProcessor.create(); } @Override @@ -52,6 +58,11 @@ public Flux receive() { return receive.doOnNext(f -> System.out.println(name + " - " + f.toString())); } + @Override + public ByteBufAllocator alloc() { + return allocator; + } + @Override public void dispose() { onClose.onComplete(); diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java index 37ad8ee5b..a30e75875 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java @@ -1,12 +1,15 @@ package io.rsocket.test.util; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ClientTransport; import reactor.core.publisher.Mono; public class TestClientTransport implements ClientTransport { - - private final TestDuplexConnection testDuplexConnection = new TestDuplexConnection(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection testDuplexConnection = new TestDuplexConnection(allocator); @Override public Mono connect(int mtu) { @@ -16,4 +19,8 @@ public Mono connect(int mtu) { public TestDuplexConnection testConnection() { return testDuplexConnection; } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java index 6298b0c3a..17a19b8c9 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import java.util.Collection; import java.util.concurrent.ConcurrentLinkedQueue; @@ -46,17 +47,19 @@ public class TestDuplexConnection implements DuplexConnection { private final FluxSink receivedSink; private final MonoProcessor onClose; private final ConcurrentLinkedQueue> sendSubscribers; + private final ByteBufAllocator allocator; private volatile double availability = 1; private volatile int initialSendRequestN = Integer.MAX_VALUE; - public TestDuplexConnection() { - sent = new LinkedBlockingQueue<>(); - received = DirectProcessor.create(); - receivedSink = received.sink(); - sentPublisher = DirectProcessor.create(); - sendSink = sentPublisher.sink(); - sendSubscribers = new ConcurrentLinkedQueue<>(); - onClose = MonoProcessor.create(); + public TestDuplexConnection(ByteBufAllocator allocator) { + this.allocator = allocator; + this.sent = new LinkedBlockingQueue<>(); + this.received = DirectProcessor.create(); + this.receivedSink = received.sink(); + this.sentPublisher = DirectProcessor.create(); + this.sendSink = sentPublisher.sink(); + this.sendSubscribers = new ConcurrentLinkedQueue<>(); + this.onClose = MonoProcessor.create(); } @Override @@ -83,6 +86,11 @@ public Flux receive() { return received; } + @Override + public ByteBufAllocator alloc() { + return allocator; + } + @Override public double availability() { return availability; diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java index 5cebf0da1..325496148 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java @@ -1,12 +1,16 @@ package io.rsocket.test.util; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ServerTransport; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; public class TestServerTransport implements ServerTransport { private final MonoProcessor conn = MonoProcessor.create(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); @Override public Mono start(ConnectionAcceptor acceptor, int mtu) { @@ -39,8 +43,12 @@ private void disposeConnection() { } public TestDuplexConnection connect() { - TestDuplexConnection c = new TestDuplexConnection(); + TestDuplexConnection c = new TestDuplexConnection(allocator); conn.onNext(c); return c; } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } } diff --git a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java b/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java deleted file mode 100644 index 526757fbe..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; -import reactor.core.publisher.Mono; - -public final class TestUriHandler implements UriHandler { - - private static final String SCHEME = "test"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of((mtu) -> Mono.just(new TestDuplexConnection())); - } - - @Override - public Optional buildServer(URI uri) { - return Optional.empty(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java b/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java deleted file mode 100644 index 7aeef708f..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static org.junit.Assert.assertTrue; - -import io.rsocket.DuplexConnection; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import org.junit.Test; - -public class UriTransportRegistryTest { - @Test - public void testTestRegistered() { - ClientTransport test = UriTransportRegistry.clientForUri("test://test"); - - DuplexConnection duplexConnection = test.connect(0).block(); - - assertTrue(duplexConnection instanceof TestDuplexConnection); - } - - @Test(expected = UnsupportedOperationException.class) - public void testTestUnregistered() { - ClientTransport test = UriTransportRegistry.clientForUri("mailto://bonson@baulsupp.net"); - - test.connect(0).block(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java new file mode 100644 index 000000000..2ad944d09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java @@ -0,0 +1,64 @@ +package io.rsocket.util; + +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ByteBufPayloadTest { + + @Test + public void shouldIndicateThatItHasMetadata() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = ByteBufPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + ByteBufPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldThrowExceptionIfAccessAfterRelease() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.release()).isTrue(); + + Assertions.assertThatThrownBy(payload::hasMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::data).isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::metadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::touch) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(() -> payload.touch("test")) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getDataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java index 45ee4eacb..6bae0886b 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java @@ -16,10 +16,13 @@ package io.rsocket.util; -import static org.hamcrest.MatcherAssert.*; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import io.netty.buffer.Unpooled; import io.rsocket.Payload; +import java.nio.ByteBuffer; +import org.assertj.core.api.Assertions; import org.junit.Test; public class DefaultPayloadTest { @@ -48,4 +51,27 @@ public void staticMethods() { assertDataAndMetadata(DefaultPayload.create(DATA_VAL, METADATA_VAL), DATA_VAL, METADATA_VAL); assertDataAndMetadata(DefaultPayload.create(DATA_VAL), DATA_VAL, null); } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = DefaultPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + DefaultPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata2() { + Payload payload = + DefaultPayload.create(ByteBuffer.wrap("data".getBytes()), ByteBuffer.allocate(0)); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + } } diff --git a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 068667aa7..000000000 --- a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.uri.TestUriHandler diff --git a/rsocket-examples/build.gradle b/rsocket-examples/build.gradle index 5f63b0761..01e80cfa1 100644 --- a/rsocket-examples/build.gradle +++ b/rsocket-examples/build.gradle @@ -22,13 +22,13 @@ dependencies { implementation project(':rsocket-core') implementation project(':rsocket-transport-local') implementation project(':rsocket-transport-netty') + runtimeOnly 'ch.qos.logback:logback-classic' testImplementation project(':rsocket-test') testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.mockito:mockito-core' testImplementation 'org.assertj:assertj-core' testImplementation 'io.projectreactor:reactor-test' - testImplementation 'ch.qos.logback:logback-classic' // TODO: Remove after JUnit5 migration testCompileOnly 'junit:junit' diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java index ac889ecfc..71e48790f 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,67 +20,53 @@ import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; -import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.transport.local.LocalClientTransport; -import io.rsocket.transport.local.LocalServerTransport; -import io.rsocket.util.ByteBufPayload; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; import java.time.Duration; import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; public final class ChannelEchoClient { - static final Payload payload1 = ByteBufPayload.create("Hello "); + + private static final Logger logger = LoggerFactory.getLogger(ChannelEchoClient.class); public static void main(String[] args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new SocketAcceptorImpl()) - .transport(LocalServerTransport.create("localhost")) - .start() + RSocketServer.create(new EchoAcceptor()) + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); RSocket socket = - RSocketFactory.connect() - .keepAliveAckTimeout(Duration.ofMinutes(10)) - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(LocalClientTransport.create("localhost")) - .start() - .block(); + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); - Flux.range(0, 100000000) - .concatMap(i -> socket.fireAndForget(payload1.retain())) - // .doOnNext(p -> { - //// System.out.println(p.getDataUtf8()); - // p.release(); - // }) - .blockLast(); + socket + .requestChannel( + Flux.interval(Duration.ofMillis(1000)).map(i -> DefaultPayload.create("Hello"))) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .doFinally(signalType -> socket.dispose()) + .then() + .block(); } - private static class SocketAcceptorImpl implements SocketAcceptor { + private static class EchoAcceptor implements SocketAcceptor { @Override public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { return Mono.just( new AbstractRSocket() { - - @Override - public Mono fireAndForget(Payload payload) { - // System.out.println(payload.getDataUtf8()); - payload.release(); - return Mono.empty(); - } - - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - @Override public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads).subscribeOn(Schedulers.single()); + return Flux.from(payloads) + .map(Payload::getDataUtf8) + .map(s -> "Echo: " + s) + .map(DefaultPayload::create); } }); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java index c0a271d66..bfa58bf40 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,8 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; @@ -30,10 +31,9 @@ public final class DuplexClient { public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( - (setup, reactiveSocket) -> { - reactiveSocket + RSocketServer.create( + (setup, rsocket) -> { + rsocket .requestStream(DefaultPayload.create("Hello-Bidi")) .map(Payload::getDataUtf8) .log() @@ -41,23 +41,22 @@ public static void main(String[] args) { return Mono.just(new AbstractRSocket() {}); }) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); RSocket socket = - RSocketFactory.connect() + RSocketConnector.create() .acceptor( - rSocket -> - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.interval(Duration.ofSeconds(1)) - .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)); - } - }) - .transport(TcpClientTransport.create("localhost", 7000)) - .start() + (setup, rsocket) -> + Mono.just( + new AbstractRSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.interval(Duration.ofSeconds(1)) + .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)); + } + })) + .connect(TcpClientTransport.create("localhost", 7000)) .block(); socket.onClose().block(); diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java index 7482c7d1a..a12c9a170 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java @@ -1,3 +1,19 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.examples.transport.tcp.lease; import static java.time.Duration.ofSeconds; @@ -5,7 +21,8 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.lease.Lease; import io.rsocket.lease.LeaseStats; import io.rsocket.lease.Leases; @@ -27,28 +44,26 @@ public class LeaseExample { public static void main(String[] args) { CloseableChannel server = - RSocketFactory.receive() + RSocketServer.create( + (setup, sendingRSocket) -> Mono.just(new ServerAcceptor(sendingRSocket))) .lease( () -> Leases.create() .sender(new LeaseSender(SERVER_TAG, 7_000, 5)) .receiver(new LeaseReceiver(SERVER_TAG)) .stats(new NoopStats())) - .acceptor((setup, sendingRSocket) -> Mono.just(new ServerAcceptor(sendingRSocket))) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + .bind(TcpServerTransport.create("localhost", 7000)) .block(); RSocket clientRSocket = - RSocketFactory.connect() + RSocketConnector.create() .lease( () -> Leases.create() .sender(new LeaseSender(CLIENT_TAG, 3_000, 5)) .receiver(new LeaseReceiver(CLIENT_TAG))) - .acceptor(rSocket -> new ClientAcceptor()) - .transport(TcpClientTransport.create(server.address())) - .start() + .acceptor((rSocket, setup) -> Mono.just(new ClientAcceptor())) + .connect(TcpClientTransport.create(server.address())) .block(); Flux.interval(ofSeconds(1)) diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java index 537485fa4..1b9994c2f 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,17 +19,21 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; public final class HelloWorldClient { + private static final Logger logger = LoggerFactory.getLogger(HelloWorldClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( + RSocketServer.create( (setupPayload, reactiveSocket) -> Mono.just( new AbstractRSocket() { @@ -39,42 +43,26 @@ public static void main(String[] args) { public Mono requestResponse(Payload p) { if (fail) { fail = false; - return Mono.error(new Throwable()); + return Mono.error(new Throwable("Simulated error")); } else { return Mono.just(p); } } })) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); - - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); - - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + for (int i = 0; i < 3; i++) { + socket + .requestResponse(DefaultPayload.create("Hello")) + .map(Payload::getDataUtf8) + .onErrorReturn("error") + .doOnNext(logger::debug) + .block(); + } socket.dispose(); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java index e6867f8b5..6724ca93f 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java @@ -3,13 +3,21 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import java.io.*; +import java.io.BufferedInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.SynchronousSink; class Files { + private static final Logger logger = LoggerFactory.getLogger(Files.class); public static Flux fileSource(String fileName, int chunkSizeBytes) { return Flux.generate( @@ -35,8 +43,7 @@ public void onNext(Payload payload) { ByteBuf data = payload.data(); receivedBytes += data.readableBytes(); receivedCount += 1; - System.out.println( - "Received file chunk: " + receivedCount + ". Total size: " + receivedBytes); + logger.debug("Received file chunk: " + receivedCount + ". Total size: " + receivedBytes); if (outputStream == null) { outputStream = open(fileName); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java index ca115d281..d449dd205 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -1,45 +1,65 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.examples.transport.tcp.resume; import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; -import io.rsocket.resume.ClientResume; -import io.rsocket.resume.PeriodicResumeStrategy; -import io.rsocket.resume.ResumeStrategy; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; public class ResumeFileTransfer { + /*amount of file chunks requested by subscriber: n, refilled on n/2 of received items*/ private static final int PREFETCH_WINDOW_SIZE = 4; + private static final Logger logger = LoggerFactory.getLogger(ResumeFileTransfer.class); public static void main(String[] args) { RequestCodec requestCodec = new RequestCodec(); + Resume resume = + new Resume() + .sessionDuration(Duration.ofMinutes(5)) + .retry( + Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)) + .doBeforeRetry( + retrySignal -> + logger.debug("Disconnected. Trying to resume connection..."))); CloseableChannel server = - RSocketFactory.receive() - .resume() - .resumeSessionDuration(Duration.ofMinutes(5)) - .acceptor((setup, rSocket) -> Mono.just(new FileServer(requestCodec))) - .transport(TcpServerTransport.create("localhost", 8000)) - .start() + RSocketServer.create((setup, rSocket) -> Mono.just(new FileServer(requestCodec))) + .resume(resume) + .bind(TcpServerTransport.create("localhost", 8000)) .block(); RSocket client = - RSocketFactory.connect() - .resume() - .resumeStrategy( - () -> new VerboseResumeStrategy(new PeriodicResumeStrategy(Duration.ofSeconds(1)))) - .resumeSessionDuration(Duration.ofMinutes(5)) - .transport(TcpClientTransport.create("localhost", 8001)) - .start() + RSocketConnector.create() + .resume(resume) + .connect(TcpClientTransport.create("localhost", 8001)) .block(); client @@ -72,20 +92,6 @@ public Flux requestStream(Payload payload) { } } - private static class VerboseResumeStrategy implements ResumeStrategy { - private final ResumeStrategy resumeStrategy; - - public VerboseResumeStrategy(ResumeStrategy resumeStrategy) { - this.resumeStrategy = resumeStrategy; - } - - @Override - public Publisher apply(ClientResume clientResume, Throwable throwable) { - return Flux.from(resumeStrategy.apply(clientResume, throwable)) - .doOnNext(v -> System.out.println("Disconnected. Trying to resume connection...")); - } - } - private static class RequestCodec { public Payload encode(Request request) { diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java index 57a659c1d..1ef2b7a90 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,33 +16,38 @@ package io.rsocket.examples.transport.tcp.stream; -import io.rsocket.*; +import io.rsocket.AbstractRSocket; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public final class StreamingClient { + private static final Logger logger = LoggerFactory.getLogger(StreamingClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor(new SocketAcceptorImpl()) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + RSocketServer.create(new SocketAcceptorImpl()) + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); socket .requestStream(DefaultPayload.create("Hello")) .map(Payload::getDataUtf8) - .doOnNext(System.out::println) + .doOnNext(logger::debug) .take(10) .then() .doFinally(signalType -> socket.dispose()) diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java index d3865c01b..24f029845 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,8 +22,10 @@ import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; @@ -45,10 +47,10 @@ public class WebSocketHeadersSample { public static void main(String[] args) { ServerTransport.ConnectionAcceptor acceptor = - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) + RSocketServer.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) .acceptor(new SocketAcceptorImpl()) - .toConnectionAcceptor(); + .asConnectionAcceptor(); DisposableServer disposableServer = HttpServer.create() @@ -61,7 +63,8 @@ public static void main(String[] args) { (in, out) -> { if (in.headers().containsValue("Authorization", "test", true)) { DuplexConnection connection = - new WebsocketDuplexConnection((Connection) in); + new ReassemblyDuplexConnection( + new WebsocketDuplexConnection((Connection) in), false); return acceptor.apply(connection).then(out.neverComplete()); } @@ -82,11 +85,10 @@ public static void main(String[] args) { }); RSocket socket = - RSocketFactory.connect() - .keepAliveAckTimeout(Duration.ofMinutes(10)) - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(clientTransport) - .start() + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport) .block(); Flux.range(0, 100) @@ -102,11 +104,10 @@ public static void main(String[] args) { WebsocketClientTransport.create(disposableServer.host(), disposableServer.port()); RSocket rSocket = - RSocketFactory.connect() - .keepAliveAckTimeout(Duration.ofMinutes(10)) - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(clientTransport2) - .start() + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport2) .block(); // expect error here because of closed channel diff --git a/rsocket-examples/src/main/resources/log4j.properties b/rsocket-examples/src/main/resources/log4j.properties deleted file mode 100644 index 035f18ebd..000000000 --- a/rsocket-examples/src/main/resources/log4j.properties +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=DEBUG, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n \ No newline at end of file diff --git a/rsocket-examples/src/main/resources/logback.xml b/rsocket-examples/src/main/resources/logback.xml new file mode 100644 index 000000000..17dd8b5e3 --- /dev/null +++ b/rsocket-examples/src/main/resources/logback.xml @@ -0,0 +1,35 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + + + diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java index 19c29061b..1ef7771cd 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -26,7 +26,8 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.RSocketInterceptor; import io.rsocket.plugins.SocketAcceptorInterceptor; @@ -39,7 +40,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.reactivestreams.Publisher; @@ -49,19 +49,20 @@ public class IntegrationTest { - private static final RSocketInterceptor requesterPlugin; - private static final RSocketInterceptor responderPlugin; - private static final SocketAcceptorInterceptor clientAcceptorPlugin; - private static final SocketAcceptorInterceptor serverAcceptorPlugin; - private static final DuplexConnectionInterceptor connectionPlugin; - public static volatile boolean calledRequester = false; - public static volatile boolean calledResponder = false; - public static volatile boolean calledClientAcceptor = false; - public static volatile boolean calledServerAcceptor = false; - public static volatile boolean calledFrame = false; + private static final RSocketInterceptor requesterInterceptor; + private static final RSocketInterceptor responderInterceptor; + private static final SocketAcceptorInterceptor clientAcceptorInterceptor; + private static final SocketAcceptorInterceptor serverAcceptorInterceptor; + private static final DuplexConnectionInterceptor connectionInterceptor; + + private static volatile boolean calledRequester = false; + private static volatile boolean calledResponder = false; + private static volatile boolean calledClientAcceptor = false; + private static volatile boolean calledServerAcceptor = false; + private static volatile boolean calledFrame = false; static { - requesterPlugin = + requesterInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override @@ -71,7 +72,7 @@ public Mono requestResponse(Payload payload) { } }; - responderPlugin = + responderInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override @@ -81,21 +82,21 @@ public Mono requestResponse(Payload payload) { } }; - clientAcceptorPlugin = + clientAcceptorInterceptor = acceptor -> (setup, sendingSocket) -> { calledClientAcceptor = true; return acceptor.accept(setup, sendingSocket); }; - serverAcceptorPlugin = + serverAcceptorInterceptor = acceptor -> (setup, sendingSocket) -> { calledServerAcceptor = true; return acceptor.accept(setup, sendingSocket); }; - connectionPlugin = + connectionInterceptor = (type, connection) -> { calledFrame = true; return connection; @@ -114,18 +115,8 @@ public void startup() { requestCount = new AtomicInteger(); disconnectionCounter = new CountDownLatch(1); - TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); - server = - RSocketFactory.receive() - .addResponderPlugin(responderPlugin) - .addSocketAcceptorPlugin(serverAcceptorPlugin) - .addConnectionPlugin(connectionPlugin) - .errorConsumer( - t -> { - errorCount.incrementAndGet(); - }) - .acceptor( + RSocketServer.create( (setup, sendingSocket) -> { sendingSocket .onClose() @@ -152,17 +143,24 @@ public Flux requestChannel(Publisher payloads) { } }); }) - .transport(serverTransport) - .start() + .interceptors( + registry -> + registry + .forResponder(responderInterceptor) + .forSocketAcceptor(serverAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) .block(); client = - RSocketFactory.connect() - .addRequesterPlugin(requesterPlugin) - .addSocketAcceptorPlugin(clientAcceptorPlugin) - .addConnectionPlugin(connectionPlugin) - .transport(TcpClientTransport.create(server.address())) - .start() + RSocketConnector.create() + .interceptors( + registry -> + registry + .forRequester(requesterInterceptor) + .forSocketAcceptor(clientAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .connect(TcpClientTransport.create(server.address())) .block(); } @@ -204,8 +202,6 @@ public void testCallRequestWithErrorAndThenRequest() { } catch (Throwable t) { } - Assert.assertEquals(1, errorCount.incrementAndGet()); - testRequest(); } } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java index 7a30a7fd1..d24083ea6 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -3,7 +3,8 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.test.SlowTest; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; @@ -21,25 +22,20 @@ public class InteractionsLoadTest { @Test @SlowTest public void channel() { - TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); - CloseableChannel server = - RSocketFactory.receive() - .acceptor((setup, rsocket) -> Mono.just(new EchoRSocket())) - .transport(serverTransport) - .start() + RSocketServer.create((setup, rsocket) -> Mono.just(new EchoRSocket())) + .bind(TcpServerTransport.create("localhost", 0)) .block(Duration.ofSeconds(10)); - TcpClientTransport transport = TcpClientTransport.create(server.address()); - - RSocket client = - RSocketFactory.connect().transport(transport).start().block(Duration.ofSeconds(10)); + RSocket clientRSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .block(Duration.ofSeconds(10)); int concurrency = 16; Flux.range(1, concurrency) .flatMap( v -> - client + clientRSocket .requestChannel( input().onBackpressureDrop().map(iv -> DefaultPayload.create("foo"))) .limitRate(10000), diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index 9e7f5b0a7..7133820ca 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -22,7 +22,8 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; @@ -46,20 +47,14 @@ public class TcpIntegrationTest { @Before public void startup() { - TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) - .transport(serverTransport) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .bind(TcpServerTransport.create("localhost", 0)) .block(); } private RSocket buildClient() { - return RSocketFactory.connect() - .transport(TcpClientTransport.create(server.address())) - .start() - .block(); + return RSocketConnector.connectWith(TcpClientTransport.create(server.address())).block(); } @After diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java index ec1d41bf9..8fe09430a 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,60 +16,52 @@ package io.rsocket.integration; -import io.rsocket.*; +import io.rsocket.AbstractRSocket; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; import io.rsocket.transport.local.LocalClientTransport; import io.rsocket.transport.local.LocalServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Supplier; import org.junit.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class TestingStreaming { - private Supplier> serverSupplier = - () -> LocalServerTransport.create("test"); - - private Supplier clientSupplier = () -> LocalClientTransport.create("test"); + LocalServerTransport serverTransport = LocalServerTransport.create("test"); @Test(expected = ApplicationErrorException.class) public void testRangeButThrowException() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .doOnNext( - i -> { - if (i > 3) { - throw new RuntimeException("BOOM!"); - } - }) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + (connectionSetupPayload, rSocket) -> + Mono.just( + new AbstractRSocket() { + @Override + public double availability() { + return 1.0; + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 1000) + .doOnNext( + i -> { + if (i > 3) { + throw new RuntimeException("BOOM!"); + } + }) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class); + } + })) + .bind(serverTransport) .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); @@ -85,29 +77,23 @@ public void testRangeOfConsumers() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + (connectionSetupPayload, rSocket) -> + Mono.just( + new AbstractRSocket() { + @Override + public double availability() { + return 1.0; + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 1000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class); + } + })) + .bind(serverTransport) .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); @@ -119,10 +105,7 @@ public Flux requestStream(Payload payload) { } private Flux consumer(String s) { - return RSocketFactory.connect() - .errorConsumer(Throwable::printStackTrace) - .transport(clientSupplier) - .start() + return RSocketConnector.connectWith(LocalClientTransport.create("test")) .flatMapMany( rSocket -> { AtomicInteger count = new AtomicInteger(); @@ -135,31 +118,25 @@ private Flux consumer(String s) { @Test public void testSingleConsumer() { Closeable server = null; - try { server = - RSocketFactory.receive() - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + (connectionSetupPayload, rSocket) -> + Mono.just( + new AbstractRSocket() { + @Override + public double availability() { + return 1.0; + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 10_000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class); + } + })) + .bind(serverTransport) .block(); consumer("1").blockLast(); diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java index 009d0d8db..bd2db39c7 100644 --- a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.exceptions.RejectedResumeException; import io.rsocket.exceptions.UnsupportedSetupException; import io.rsocket.test.SlowTest; @@ -33,15 +35,14 @@ import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.ReplayProcessor; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; +import reactor.util.retry.Retry; @SlowTest public class ResumeIntegrationTest { @@ -103,11 +104,9 @@ public void reconnectOnMissingSession() { DisconnectableClientTransport clientTransport = new DisconnectableClientTransport(clientTransport(closeable.address())); - ErrorConsumer errorConsumer = new ErrorConsumer(); int clientSessionDurationSeconds = 10; - RSocket rSocket = - newClientRSocket(clientTransport, clientSessionDurationSeconds, errorConsumer).block(); + RSocket rSocket = newClientRSocket(clientTransport, clientSessionDurationSeconds).block(); Mono.delay(Duration.ofSeconds(1)) .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(3))); @@ -117,43 +116,34 @@ public void reconnectOnMissingSession() { .expectError() .verify(Duration.ofSeconds(5)); - StepVerifier.create(errorConsumer.errors().next()) - .expectNextMatches( + StepVerifier.create(rSocket.onClose()) + .expectErrorMatches( err -> err instanceof RejectedResumeException && "unknown resume token".equals(err.getMessage())) - .expectComplete() .verify(Duration.ofSeconds(5)); } @Test void serverMissingResume() { CloseableChannel closeableChannel = - RSocketFactory.receive() - .acceptor((setupPayload, rSocket) -> Mono.just(new TestResponderRSocket())) - .transport(serverTransport(SERVER_HOST, SERVER_PORT)) - .start() + RSocketServer.create((setupPayload, rSocket) -> Mono.just(new TestResponderRSocket())) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)) .block(); - ErrorConsumer errorConsumer = new ErrorConsumer(); - RSocket rSocket = - RSocketFactory.connect() - .resume() - .errorConsumer(errorConsumer) - .transport(clientTransport(closeableChannel.address())) - .start() + RSocketConnector.create() + .resume(new Resume()) + .connect(clientTransport(closeableChannel.address())) .block(); - StepVerifier.create(errorConsumer.errors().next().doFinally(s -> closeableChannel.dispose())) - .expectNextMatches( + StepVerifier.create(rSocket.onClose().doFinally(s -> closeableChannel.dispose())) + .expectErrorMatches( err -> err instanceof UnsupportedSetupException && "resume not supported".equals(err.getMessage())) - .expectComplete() .verify(Duration.ofSeconds(5)); - StepVerifier.create(rSocket.onClose()).expectComplete().verify(Duration.ofSeconds(5)); Assertions.assertThat(rSocket.isDisposed()).isTrue(); } @@ -165,21 +155,8 @@ static ServerTransport serverTransport(String host, int port) return TcpServerTransport.create(host, port); } - private static class ErrorConsumer implements Consumer { - private final ReplayProcessor errors = ReplayProcessor.create(); - - public Flux errors() { - return errors; - } - - @Override - public void accept(Throwable throwable) { - errors.onNext(throwable); - } - } - private static Flux testRequest() { - return Flux.interval(Duration.ofMillis(50)) + return Flux.interval(Duration.ofMillis(500)) .map(v -> DefaultPayload.create("client_request")) .onBackpressureDrop(); } @@ -201,24 +178,15 @@ private void throwOnNonContinuous(AtomicInteger counter, String x) { private static Mono newClientRSocket( DisconnectableClientTransport clientTransport, int sessionDurationSeconds) { - return newClientRSocket(clientTransport, sessionDurationSeconds, err -> {}); - } - - private static Mono newClientRSocket( - DisconnectableClientTransport clientTransport, - int sessionDurationSeconds, - Consumer errConsumer) { - return RSocketFactory.connect() - .resume() - .resumeSessionDuration(Duration.ofSeconds(sessionDurationSeconds)) - .resumeStore(t -> new InMemoryResumableFramesStore("client", 500_000)) - .resumeCleanupOnKeepAlive() - .keepAliveTickPeriod(Duration.ofSeconds(5)) - .keepAliveAckTimeout(Duration.ofMinutes(5)) - .errorConsumer(errConsumer) - .resumeStrategy(() -> new PeriodicResumeStrategy(Duration.ofSeconds(1))) - .transport(clientTransport) - .start(); + return RSocketConnector.create() + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .storeFactory(t -> new InMemoryResumableFramesStore("client", 500_000)) + .cleanupStoreOnKeepAlive() + .retry(Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)))) + .keepAlive(Duration.ofSeconds(5), Duration.ofMinutes(5)) + .connect(clientTransport); } private static Mono newServerRSocket() { @@ -226,14 +194,13 @@ private static Mono newServerRSocket() { } private static Mono newServerRSocket(int sessionDurationSeconds) { - return RSocketFactory.receive() - .resume() - .resumeStore(t -> new InMemoryResumableFramesStore("server", 500_000)) - .resumeSessionDuration(Duration.ofSeconds(sessionDurationSeconds)) - .resumeCleanupOnKeepAlive() - .acceptor((setupPayload, rSocket) -> Mono.just(new TestResponderRSocket())) - .transport(serverTransport(SERVER_HOST, SERVER_PORT)) - .start(); + return RSocketServer.create((setup, rsocket) -> Mono.just(new TestResponderRSocket())) + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .cleanupStoreOnKeepAlive() + .storeFactory(t -> new InMemoryResumableFramesStore("server", 500_000))) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)); } private static class TestResponderRSocket extends AbstractRSocket { diff --git a/rsocket-examples/src/test/resources/log4j.properties b/rsocket-examples/src/test/resources/log4j.properties deleted file mode 100644 index 51731fc15..000000000 --- a/rsocket-examples/src/test/resources/log4j.properties +++ /dev/null @@ -1,21 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss,SSS} %5p [%t] (%F) - %m%n -#log4j.logger.io.rsocket.FrameLogger=Debug \ No newline at end of file diff --git a/rsocket-examples/src/test/resources/logback-test.xml b/rsocket-examples/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-examples/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-load-balancer/build.gradle b/rsocket-load-balancer/build.gradle index a2c8b73c7..748f95de6 100644 --- a/rsocket-load-balancer/build.gradle +++ b/rsocket-load-balancer/build.gradle @@ -34,6 +34,7 @@ dependencies { testCompileOnly 'junit:junit' testImplementation 'org.hamcrest:hamcrest-library' testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + testRuntimeOnly 'ch.qos.logback:logback-classic' } description = 'Transparent Load Balancer for RSocket' diff --git a/rsocket-load-balancer/src/test/resources/log4j.properties b/rsocket-load-balancer/src/test/resources/log4j.properties deleted file mode 100644 index 8fc3a9cdd..000000000 --- a/rsocket-load-balancer/src/test/resources/log4j.properties +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] (%F:%L) - %m%n \ No newline at end of file diff --git a/rsocket-load-balancer/src/test/resources/logback-test.xml b/rsocket-load-balancer/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-load-balancer/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java index 20d58dcb7..9904c2b24 100644 --- a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java @@ -20,6 +20,7 @@ import io.micrometer.core.instrument.*; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameType; @@ -82,6 +83,11 @@ final class MicrometerDuplexConnection implements DuplexConnection { this.frameCounters = new FrameCounters(connectionType, meterRegistry, tags); } + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + @Override public void dispose() { delegate.dispose(); diff --git a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java index ec143b7ab..6f562875f 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java +++ b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,8 @@ import io.rsocket.Closeable; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.function.BiFunction; @@ -47,17 +48,13 @@ public ClientSetupRule( this.serverInit = address -> - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .transport(serverTransportSupplier.apply(address)) - .start() + RSocketServer.create((setup, rsocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) .block(); this.clientConnector = (address, server) -> - RSocketFactory.connect() - .transport(clientTransportSupplier.apply(address, server)) - .start() + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) .doOnError(Throwable::printStackTrace) .block(); } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java index 2651b14ec..60ff05124 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java @@ -87,12 +87,12 @@ public static ByteBuf createTestRequestNFrame() { /** @return {@link ByteBuf} representing test instance of Request-Response frame */ public static ByteBuf createTestRequestResponseFrame() { - return RequestResponseFrameFlyweight.encode(allocator, 1, false, emptyPayload); + return RequestResponseFrameFlyweight.encodeReleasingPayload(allocator, 1, emptyPayload); } /** @return {@link ByteBuf} representing test instance of Request-Stream frame */ public static ByteBuf createTestRequestStreamFrame() { - return RequestStreamFrameFlyweight.encode(allocator, 1, false, 1L, emptyPayload); + return RequestStreamFrameFlyweight.encodeReleasingPayload(allocator, 1, 1L, emptyPayload); } /** @return {@link ByteBuf} representing test instance of Setup frame */ diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index 57a2e5c3c..26163d3a6 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -55,6 +55,6 @@ public Mono fireAndForget(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { // TODO is defensive copy neccesary? - return Flux.from(payloads).map(DefaultPayload::create); + return Flux.from(payloads).map(Payload::retain); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index fc6301d7d..583f58634 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,28 +19,62 @@ import io.rsocket.Closeable; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.util.DefaultPayload; +import java.io.BufferedReader; +import java.io.InputStreamReader; import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import reactor.core.Disposable; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; public interface TransportTest { + String MOCK_DATA = "test-data"; + String MOCK_METADATA = "metadata"; + String LARGE_DATA = read("words.shakespeare.txt.gz"); + Payload LARGE_PAYLOAD = DefaultPayload.create(LARGE_DATA, LARGE_DATA); + + static String read(String resourceName) { + + try (BufferedReader br = + new BufferedReader( + new InputStreamReader( + new GZIPInputStream( + TransportTest.class.getClassLoader().getResourceAsStream(resourceName))))) { + + return br.lines().map(String::toLowerCase).collect(Collectors.joining("\n\r")); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @BeforeEach + default void setUp() { + Hooks.onOperatorDebug(); + } + @AfterEach default void close() { getTransportPair().dispose(); + Hooks.resetOnOperatorDebug(); } default Payload createTestPayload(int metadataPresent) { @@ -54,12 +88,12 @@ default Payload createTestPayload(int metadataPresent) { metadata1 = ""; break; default: - metadata1 = "metadata"; + metadata1 = MOCK_METADATA; break; } String metadata = metadata1; - return DefaultPayload.create("test-data", metadata); + return DefaultPayload.create(MOCK_DATA, metadata); } @DisplayName("makes 10 fireAndForget requests") @@ -73,6 +107,17 @@ default void fireAndForget10() { .verify(getTimeout()); } + @DisplayName("makes 10 fireAndForget with Large Payload in Requests") + @Test + default void largePayloadFireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD)) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + default RSocket getClient() { return getTransportPair().getClient(); } @@ -92,6 +137,17 @@ default void metadataPush10() { .verify(getTimeout()); } + @DisplayName("makes 10 metadataPush with Large Metadata in requests") + @Test + default void largePayloadMetadataPush10() { + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(DefaultPayload.create("", LARGE_DATA))) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 1 requestChannel request with 0 payloads") @Test default void requestChannel0() { @@ -127,6 +183,19 @@ default void requestChannel200_000() { .verify(getTimeout()); } + @DisplayName("makes 1 requestChannel request with 200 large payloads") + @Test + default void largePayloadRequestChannel200() { + Flux payloads = Flux.range(0, 200).map(__ -> LARGE_PAYLOAD); + + getClient() + .requestChannel(payloads) + .as(StepVerifier::create) + .expectNextCount(200) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 1 requestChannel request with 20,000 payloads") @Test default void requestChannel20_000() { @@ -157,14 +226,19 @@ default void requestChannel2_000_000() { @DisplayName("makes 1 requestChannel request with 3 payloads") @Test default void requestChannel3() { - Flux payloads = Flux.range(0, 3).map(this::createTestPayload); + AtomicLong requested = new AtomicLong(); + Flux payloads = + Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); getClient() .requestChannel(payloads) - .as(StepVerifier::create) + .as(publisher -> StepVerifier.create(publisher, 3)) .expectNextCount(3) .expectComplete() .verify(getTimeout()); + + Assertions.assertThat(requested.get()) + .isEqualTo(256L); // 256 because of eager behavior of limitRate } @DisplayName("makes 1 requestChannel request with 512 payloads") @@ -223,6 +297,17 @@ default void requestResponse100() { .verify(getTimeout()); } + @DisplayName("makes 100 requestResponse requests") + @Test + default void largePayloadRequestResponse100() { + Flux.range(1, 100) + .flatMap(i -> getClient().requestResponse(LARGE_PAYLOAD).map(Payload::getDataUtf8)) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 10,000 requestResponse requests") @Test default void requestResponse10_000() { @@ -283,7 +368,7 @@ default void assertPayload(Payload p) { } default void assertChannelPayload(Payload p) { - if (!"test-data".equals(p.getDataUtf8()) || !"metadata".equals(p.getMetadataUtf8())) { + if (!MOCK_DATA.equals(p.getDataUtf8()) || !MOCK_METADATA.equals(p.getMetadataUtf8())) { throw new IllegalStateException("Unexpected payload"); } } @@ -304,16 +389,12 @@ public TransportPair( T address = addressSupplier.get(); server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .transport(serverTransportSupplier.apply(address)) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) .block(); client = - RSocketFactory.connect() - .transport(clientTransportSupplier.apply(address, server)) - .start() + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) .doOnError(Throwable::printStackTrace) .block(); } diff --git a/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java b/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java deleted file mode 100644 index ad45e106a..000000000 --- a/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; - -import io.rsocket.uri.UriHandler; -import java.net.URI; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -public interface UriHandlerTest { - - @DisplayName("returns empty Optional client with invalid URI") - @Test - default void buildClientInvalidUri() { - assertThat(getUriHandler().buildClient(URI.create(getInvalidUri()))).isEmpty(); - } - - @DisplayName("buildClient throws NullPointerException with null uri") - @Test - default void buildClientNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> getUriHandler().buildClient(null)) - .withMessage("uri must not be null"); - } - - @DisplayName("returns client with value URI") - @Test - default void buildClientValidUri() { - assertThat(getUriHandler().buildClient(URI.create(getValidUri()))).isNotEmpty(); - } - - @DisplayName("returns empty Optional server with invalid URI") - @Test - default void buildServerInvalidUri() { - assertThat(getUriHandler().buildServer(URI.create(getInvalidUri()))).isEmpty(); - } - - @DisplayName("buildServer throws NullPointerException with null uri") - @Test - default void buildServerNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> getUriHandler().buildServer(null)) - .withMessage("uri must not be null"); - } - - @DisplayName("returns server with value URI") - @Test - default void buildServerValidUri() { - assertThat(getUriHandler().buildServer(URI.create(getValidUri()))).isNotEmpty(); - } - - String getInvalidUri(); - - UriHandler getUriHandler(); - - String getValidUri(); -} diff --git a/rsocket-test/src/main/resources/words.shakespeare.txt.gz b/rsocket-test/src/main/resources/words.shakespeare.txt.gz new file mode 100644 index 000000000..422a4b331 Binary files /dev/null and b/rsocket-test/src/main/resources/words.shakespeare.txt.gz differ diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index 990acddfe..d69bd65e8 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -20,6 +20,7 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; @@ -36,21 +37,39 @@ public final class LocalClientTransport implements ClientTransport { private final String name; - private LocalClientTransport(String name) { + private final ByteBufAllocator allocator; + + private LocalClientTransport(String name, ByteBufAllocator allocator) { this.name = name; + this.allocator = allocator; } /** * Creates a new instance. * - * @param name the name of the {@link ServerTransport} instance to connect to + * @param name the name of the {@link ClientTransport} instance to connect to * @return a new instance * @throws NullPointerException if {@code name} is {@code null} */ public static LocalClientTransport create(String name) { Objects.requireNonNull(name, "name must not be null"); - return new LocalClientTransport(name); + return create(name, ByteBufAllocator.DEFAULT); + } + + /** + * Creates a new instance. + * + * @param name the name of the {@link ClientTransport} instance to connect to + * @param allocator the allocator used by {@link ClientTransport} instance + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ + public static LocalClientTransport create(String name, ByteBufAllocator allocator) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(allocator, "allocator must not be null"); + + return new LocalClientTransport(name, allocator); } private Mono connect() { @@ -65,9 +84,10 @@ private Mono connect() { UnboundedProcessor out = new UnboundedProcessor<>(); MonoProcessor closeNotifier = MonoProcessor.create(); - server.accept(new LocalDuplexConnection(out, in, closeNotifier)); + server.accept(new LocalDuplexConnection(allocator, out, in, closeNotifier)); - return Mono.just((DuplexConnection) new LocalDuplexConnection(in, out, closeNotifier)); + return Mono.just( + (DuplexConnection) new LocalDuplexConnection(allocator, in, out, closeNotifier)); }); } @@ -75,13 +95,14 @@ private Mono connect() { public Mono connect(int mtu) { Mono isError = FragmentationDuplexConnection.checkMtu(mtu); Mono connect = isError != null ? isError : connect(); - if (mtu > 0) { - return connect.map( - duplexConnection -> - new FragmentationDuplexConnection( - duplexConnection, ByteBufAllocator.DEFAULT, mtu, false, "client")); - } else { - return connect; - } + + return connect.map( + duplexConnection -> { + if (mtu > 0) { + return new FragmentationDuplexConnection(duplexConnection, mtu, false, "client"); + } else { + return new ReassemblyDuplexConnection(duplexConnection, false); + } + }); } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java index f9501717c..afaa14f95 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.transport.local; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import java.util.Objects; import org.reactivestreams.Publisher; @@ -28,6 +29,7 @@ /** An implementation of {@link DuplexConnection} that connects inside the same JVM. */ final class LocalDuplexConnection implements DuplexConnection { + private final ByteBufAllocator allocator; private final Flux in; private final MonoProcessor onClose; @@ -42,7 +44,12 @@ final class LocalDuplexConnection implements DuplexConnection { * @param onClose the closing notifier * @throws NullPointerException if {@code in}, {@code out}, or {@code onClose} are {@code null} */ - LocalDuplexConnection(Flux in, Subscriber out, MonoProcessor onClose) { + LocalDuplexConnection( + ByteBufAllocator allocator, + Flux in, + Subscriber out, + MonoProcessor onClose) { + this.allocator = Objects.requireNonNull(allocator, "allocator must not be null"); this.in = Objects.requireNonNull(in, "in must not be null"); this.out = Objects.requireNonNull(out, "out must not be null"); this.onClose = Objects.requireNonNull(onClose, "onClose must not be null"); @@ -82,4 +89,9 @@ public Mono sendOne(ByteBuf frame) { out.onNext(frame); return Mono.empty(); } + + @Override + public ByteBufAllocator alloc() { + return allocator; + } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index d755859d2..382b4533a 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -16,10 +16,10 @@ package io.rsocket.transport.local; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.Objects; @@ -166,8 +166,9 @@ public void accept(DuplexConnection duplexConnection) { if (mtu > 0) { duplexConnection = - new FragmentationDuplexConnection( - duplexConnection, ByteBufAllocator.DEFAULT, mtu, false, "server"); + new FragmentationDuplexConnection(duplexConnection, mtu, false, "server"); + } else { + duplexConnection = new ReassemblyDuplexConnection(duplexConnection, false); } acceptor.apply(duplexConnection).subscribe(); diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java deleted file mode 100644 index 89c816d7a..000000000 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - -/** - * An implementation of {@link UriHandler} that creates {@link LocalClientTransport}s and {@link - * LocalServerTransport}s. - */ -public final class LocalUriHandler implements UriHandler { - - private static final String SCHEME = "local"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(LocalClientTransport.create(uri.getSchemeSpecificPart())); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(LocalServerTransport.create(uri.getSchemeSpecificPart())); - } -} diff --git a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 6ff8ffb50..000000000 --- a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.local.LocalUriHandler diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java index 2e4f93ac4..9228e2d05 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ package io.rsocket.transport.local; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.test.PingHandler; @@ -28,18 +29,15 @@ public final class LocalPingPong { public static void main(String... args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(LocalServerTransport.create("test-local-server")) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(LocalServerTransport.create("test-local-server")) .block(); Mono client = - RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(LocalClientTransport.create("test-local-server")) - .start(); + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(LocalClientTransport.create("test-local-server")); PingClient pingClient = new PingClient(client); diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java deleted file mode 100644 index ed8e6cd1d..000000000 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; - -final class LocalUriHandlerTest implements UriHandlerTest { - - @Override - public String getInvalidUri() { - return "http://test"; - } - - @Override - public UriHandler getUriHandler() { - return new LocalUriHandler(); - } - - @Override - public String getValidUri() { - return "local:test"; - } -} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java deleted file mode 100644 index f6b5cda7e..000000000 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class LocalUriTransportRegistryTest { - - @DisplayName("local URI returns LocalClientTransport") - @Test - void clientForUri() { - assertThat(UriTransportRegistry.clientForUri("local:test1")) - .isInstanceOf(LocalClientTransport.class); - } - - @DisplayName("non-local URI does not return LocalClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("http://localhost")) - .isNotInstanceOf(LocalClientTransport.class); - } - - @DisplayName("local URI returns LocalServerTransport") - @Test - void serverForUri() { - assertThat(UriTransportRegistry.serverForUri("local:test1")) - .isInstanceOf(LocalServerTransport.class); - } - - @DisplayName("non-local URI does not return LocalServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("http://localhost")) - .isNotInstanceOf(LocalServerTransport.class); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java index c9c29f0a9..d71d6b356 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -31,7 +31,6 @@ public final class TcpDuplexConnection extends BaseDuplexConnection { private final Connection connection; - private final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; private final boolean encodeLength; /** @@ -62,6 +61,11 @@ public TcpDuplexConnection(Connection connection, boolean encodeLength) { }); } + @Override + public ByteBufAllocator alloc() { + return connection.channel().alloc(); + } + @Override protected void doOnClose() { if (!connection.isDisposed()) { @@ -84,7 +88,7 @@ public Mono send(Publisher frames) { private ByteBuf encode(ByteBuf frame) { if (encodeLength) { - return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame); + return FrameLengthFlyweight.encode(alloc(), frame.readableBytes(), frame); } else { return frame; } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java deleted file mode 100644 index d4ebd57b7..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; -import reactor.netty.tcp.TcpServer; - -/** - * An implementation of {@link UriHandler} that creates {@link TcpClientTransport}s and {@link - * TcpServerTransport}s. - */ -public final class TcpUriHandler implements UriHandler { - - private static final String SCHEME = "tcp"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(TcpClientTransport.create(uri.getHost(), uri.getPort())); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of( - TcpServerTransport.create(TcpServer.create().host(uri.getHost()).port(uri.getPort()))); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java index ead297928..0183ef19d 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -16,6 +16,7 @@ package io.rsocket.transport.netty; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.rsocket.DuplexConnection; import io.rsocket.internal.BaseDuplexConnection; @@ -53,6 +54,11 @@ public WebsocketDuplexConnection(Connection connection) { }); } + @Override + public ByteBufAllocator alloc() { + return connection.channel().alloc(); + } + @Override protected void doOnClose() { if (!connection.isDisposed()) { diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java deleted file mode 100644 index 6438c4e28..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static io.rsocket.transport.netty.UriUtils.getPort; -import static io.rsocket.transport.netty.UriUtils.isSecure; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** - * An implementation of {@link UriHandler} that creates {@link WebsocketClientTransport}s and {@link - * WebsocketServerTransport}s. - */ -public final class WebsocketUriHandler implements UriHandler { - - private static final List SCHEME = Arrays.asList("ws", "wss", "http", "https"); - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (SCHEME.stream().noneMatch(scheme -> scheme.equals(uri.getScheme()))) { - return Optional.empty(); - } - - return Optional.of(WebsocketClientTransport.create(uri)); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (SCHEME.stream().noneMatch(scheme -> scheme.equals(uri.getScheme()))) { - return Optional.empty(); - } - - int port = isSecure(uri) ? getPort(uri, 443) : getPort(uri, 80); - - return Optional.of(WebsocketServerTransport.create(uri.getHost(), port)); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index f5e79e9bf..8be019f1c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -16,9 +16,9 @@ package io.rsocket.transport.netty.client; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -104,13 +104,9 @@ public Mono connect(int mtu) { c -> { if (mtu > 0) { return new FragmentationDuplexConnection( - new TcpDuplexConnection(c, false), - ByteBufAllocator.DEFAULT, - mtu, - true, - "client"); + new TcpDuplexConnection(c, false), mtu, true, "client"); } else { - return new TcpDuplexConnection(c); + return new ReassemblyDuplexConnection(new TcpDuplexConnection(c), false); } }); } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index 5049119a5..b19621d46 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -20,9 +20,9 @@ import static io.rsocket.transport.netty.UriUtils.getPort; import static io.rsocket.transport.netty.UriUtils.isSecure; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; @@ -43,7 +43,6 @@ */ public final class WebsocketClientTransport implements ClientTransport, TransportHeaderAware { - private static final int DEFAULT_FRAME_SIZE = 65536; private static final String DEFAULT_PATH = "/"; private final HttpClient client; @@ -164,8 +163,9 @@ public Mono connect(int mtu) { DuplexConnection connection = new WebsocketDuplexConnection(c); if (mtu > 0) { connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "client"); + new FragmentationDuplexConnection(connection, mtu, false, "client"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); } return connection; }); diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index 54ef016c0..56dd59d45 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -16,9 +16,9 @@ package io.rsocket.transport.netty.server; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -105,13 +105,9 @@ public Mono start(ConnectionAcceptor acceptor, int mtu) { if (mtu > 0) { connection = new FragmentationDuplexConnection( - new TcpDuplexConnection(c, false), - ByteBufAllocator.DEFAULT, - mtu, - true, - "server"); + new TcpDuplexConnection(c, false), mtu, true, "server"); } else { - connection = new TcpDuplexConnection(c); + connection = new ReassemblyDuplexConnection(new TcpDuplexConnection(c), false); } acceptor .apply(connection) diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index 30aa0fa96..83cb010b7 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -18,27 +18,21 @@ import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; -import io.netty.buffer.ByteBufAllocator; -import io.netty.handler.codec.http.HttpMethod; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Consumer; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServerRoutes; +import reactor.netty.http.server.WebsocketServerSpec; import reactor.netty.http.websocket.WebsocketInbound; import reactor.netty.http.websocket.WebsocketOutbound; @@ -48,7 +42,7 @@ */ public final class WebsocketRouteTransport extends BaseWebsocketServerTransport { - private final UriPathTemplate template; + private final String path; private final Consumer routesBuilder; @@ -65,7 +59,7 @@ public WebsocketRouteTransport( HttpServer server, Consumer routesBuilder, String path) { this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null"); - this.template = new UriPathTemplate(Objects.requireNonNull(path, "path must not be null")); + this.path = Objects.requireNonNull(path, "path must not be null"); } @Override @@ -77,10 +71,9 @@ public Mono start(ConnectionAcceptor acceptor, int mtu) { routes -> { routesBuilder.accept(routes); routes.ws( - hsr -> hsr.method().equals(HttpMethod.GET) && template.matches(hsr.uri()), + path, newHandler(acceptor, mtu), - null, - FRAME_LENGTH_MASK); + WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build()); }) .bind() .map(CloseableChannel::new); @@ -111,128 +104,11 @@ public static BiFunction> n return (in, out) -> { DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); if (mtu > 0) { - connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "server"); + connection = new FragmentationDuplexConnection(connection, mtu, false, "server"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); } return acceptor.apply(connection).then(out.neverComplete()); }; } - - static final class UriPathTemplate { - - private static final Pattern FULL_SPLAT_PATTERN = Pattern.compile("[\\*][\\*]"); - private static final String FULL_SPLAT_REPLACEMENT = ".*"; - - private static final Pattern NAME_SPLAT_PATTERN = Pattern.compile("\\{([^/]+?)\\}[\\*][\\*]"); - private static final String NAME_SPLAT_REPLACEMENT = "(?<%NAME%>.*)"; - - private static final Pattern NAME_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); - private static final String NAME_REPLACEMENT = "(?<%NAME%>[^\\/]*)"; - - private final List pathVariables = new ArrayList<>(); - private final HashMap matchers = new HashMap<>(); - private final HashMap> vars = new HashMap<>(); - - private final Pattern uriPattern; - - static String filterQueryParams(String uri) { - int hasQuery = uri.lastIndexOf("?"); - if (hasQuery != -1) { - return uri.substring(0, hasQuery); - } else { - return uri; - } - } - - /** - * Creates a new {@code UriPathTemplate} from the given {@code uriPattern}. - * - * @param uriPattern The pattern to be used by the template - */ - UriPathTemplate(String uriPattern) { - String s = "^" + filterQueryParams(uriPattern); - - Matcher m = NAME_SPLAT_PATTERN.matcher(s); - while (m.find()) { - for (int i = 1; i <= m.groupCount(); i++) { - String name = m.group(i); - pathVariables.add(name); - s = m.replaceFirst(NAME_SPLAT_REPLACEMENT.replaceAll("%NAME%", name)); - m.reset(s); - } - } - - m = NAME_PATTERN.matcher(s); - while (m.find()) { - for (int i = 1; i <= m.groupCount(); i++) { - String name = m.group(i); - pathVariables.add(name); - s = m.replaceFirst(NAME_REPLACEMENT.replaceAll("%NAME%", name)); - m.reset(s); - } - } - - m = FULL_SPLAT_PATTERN.matcher(s); - while (m.find()) { - s = m.replaceAll(FULL_SPLAT_REPLACEMENT); - m.reset(s); - } - - this.uriPattern = Pattern.compile(s + "$"); - } - - /** - * Tests the given {@code uri} against this template, returning {@code true} if the uri matches - * the template, {@code false} otherwise. - * - * @param uri The uri to match - * @return {@code true} if there's a match, {@code false} otherwise - */ - public boolean matches(String uri) { - return matcher(uri).matches(); - } - - /** - * Matches the template against the given {@code uri} returning a map of path parameters - * extracted from the uri, keyed by the names in the template. If the uri does not match, or - * there are no path parameters, an empty map is returned. - * - * @param uri The uri to match - * @return the path parameters from the uri. Never {@code null}. - */ - final Map match(String uri) { - Map pathParameters = vars.get(uri); - if (null != pathParameters) { - return pathParameters; - } - - pathParameters = new HashMap<>(); - Matcher m = matcher(uri); - if (m.matches()) { - int i = 1; - for (String name : pathVariables) { - String val = m.group(i++); - pathParameters.put(name, val); - } - } - synchronized (vars) { - vars.put(uri, pathParameters); - } - - return pathParameters; - } - - private Matcher matcher(String uri) { - uri = filterQueryParams(uri); - Matcher m = matchers.get(uri); - if (null == m) { - m = uriPattern.matcher(uri); - synchronized (matchers) { - matchers.put(uri, m); - } - } - return m; - } - } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index 948d6f573..4a0331c08 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -18,9 +18,9 @@ import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; @@ -35,6 +35,7 @@ import reactor.core.publisher.Mono; import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.WebsocketServerSpec; /** * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a @@ -122,18 +123,20 @@ public Mono start(ConnectionAcceptor acceptor, int mtu) { (request, response) -> { transportHeaders.get().forEach(response::addHeader); return response.sendWebsocket( - null, - FRAME_LENGTH_MASK, (in, out) -> { DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); if (mtu > 0) { connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "server"); + new FragmentationDuplexConnection(connection, mtu, false, "server"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); } return acceptor.apply(connection).then(out.neverComplete()); - }); + }, + WebsocketServerSpec.builder() + .maxFramePayloadLength(FRAME_LENGTH_MASK) + .build()); }) .bind() .map(CloseableChannel::new); diff --git a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index ec7ddcb80..000000000 --- a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.netty.TcpUriHandler -io.rsocket.transport.netty.WebsocketUriHandler diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java index 575993c18..0ea938af2 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,29 +21,34 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import io.rsocket.util.RSocketProxy; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class FragmentTest { - private static final int frameSize = 64; private AbstractRSocket handler; private CloseableChannel server; private String message = null; private String metaData = null; private String responseMessage = null; - @BeforeEach - public void startup() { + private static Stream cases() { + return Stream.of(Arguments.of(0, 64), Arguments.of(64, 0), Arguments.of(64, 64)); + } + + public void startup(int frameSize) { int randomPort = ThreadLocalRandom.current().nextInt(10_000, 20_000); StringBuilder message = new StringBuilder(); StringBuilder responseMessage = new StringBuilder(); @@ -59,19 +64,16 @@ public void startup() { TcpServerTransport serverTransport = TcpServerTransport.create("localhost", randomPort); server = - RSocketFactory.receive() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) .fragment(frameSize) - .acceptor((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) - .transport(serverTransport) - .start() + .bind(serverTransport) .block(); } - private RSocket buildClient() { - return RSocketFactory.connect() + private RSocket buildClient(int frameSize) { + return RSocketConnector.create() .fragment(frameSize) - .transport(TcpClientTransport.create(server.address())) - .start() + .connect(TcpClientTransport.create(server.address())) .block(); } @@ -80,8 +82,10 @@ public void cleanup() { server.dispose(); } - @Test - void testFragmentNoMetaData() { + @ParameterizedTest + @MethodSource("cases") + void testFragmentNoMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); System.out.println( "-------------------------------------------------testFragmentNoMetaData-------------------------------------------------"); handler = @@ -97,7 +101,7 @@ public Flux requestStream(Payload payload) { } }; - RSocket client = buildClient(); + RSocket client = buildClient(clientFrameSize); System.out.println("original message: " + message); System.out.println("original metadata: " + metaData); @@ -108,8 +112,10 @@ public Flux requestStream(Payload payload) { assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); } - @Test - void testFragmentRequestMetaDataOnly() { + @ParameterizedTest + @MethodSource("cases") + void testFragmentRequestMetaDataOnly(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); System.out.println( "-------------------------------------------------testFragmentRequestMetaDataOnly-------------------------------------------------"); handler = @@ -125,7 +131,7 @@ public Flux requestStream(Payload payload) { } }; - RSocket client = buildClient(); + RSocket client = buildClient(clientFrameSize); System.out.println("original message: " + message); System.out.println("original metadata: " + metaData); @@ -136,8 +142,10 @@ public Flux requestStream(Payload payload) { assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); } - @Test - void testFragmentBothMetaData() { + @ParameterizedTest + @MethodSource("cases") + void testFragmentBothMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); Payload responsePayload = DefaultPayload.create(responseMessage); System.out.println( "-------------------------------------------------testFragmentBothMetaData-------------------------------------------------"); @@ -164,7 +172,7 @@ public Mono requestResponse(Payload payload) { } }; - RSocket client = buildClient(); + RSocket client = buildClient(clientFrameSize); System.out.println("original message: " + message); System.out.println("original metadata: " + metaData); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java index 07e9378fa..b9c0d4f60 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java @@ -1,18 +1,15 @@ package io.rsocket.transport.netty; -import static io.rsocket.RSocketFactory.*; - import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; -import io.rsocket.transport.ClientTransport; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.transport.netty.server.WebsocketServerTransport; import java.time.Duration; -import java.util.function.Function; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; @@ -22,104 +19,62 @@ class RSocketFactoryNettyTransportFragmentationTest { - @ParameterizedTest - @MethodSource("serverTransportProvider") - void serverErrorsWithEnabledFragmentationOnInsufficientMtu( - ServerTransport serverTransport) { - Mono server = createServer(serverTransport, f -> f.fragment(2)); - - StepVerifier.create(server) - .expectErrorMatches( - err -> - err instanceof IllegalArgumentException - && "smallest allowed mtu size is 64 bytes, provided: 2" - .equals(err.getMessage())) - .verify(Duration.ofSeconds(5)); + static Stream> arguments() { + return Stream.of(TcpServerTransport.create(0), WebsocketServerTransport.create(0)); } @ParameterizedTest - @MethodSource("serverTransportProvider") + @MethodSource("arguments") void serverSucceedsWithEnabledFragmentationOnSufficientMtu( ServerTransport serverTransport) { Mono server = - createServer(serverTransport, f -> f.fragment(100)).doOnNext(CloseableChannel::dispose); + RSocketServer.create(mockAcceptor()) + .fragment(100) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } @ParameterizedTest - @MethodSource("serverTransportProvider") - void serverSucceedsWithDisabledFragmentation() { + @MethodSource("arguments") + void serverSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { Mono server = - createServer(TcpServerTransport.create("localhost", 0), Function.identity()) + RSocketServer.create(mockAcceptor()) + .bind(serverTransport) .doOnNext(CloseableChannel::dispose); StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } @ParameterizedTest - @MethodSource("serverTransportProvider") - void clientErrorsWithEnabledFragmentationOnInsufficientMtu( - ServerTransport serverTransport) { - CloseableChannel server = createServer(serverTransport, f -> f.fragment(100)).block(); - - Mono rSocket = - createClient(TcpClientTransport.create(server.address()), f -> f.fragment(2)) - .doFinally(s -> server.dispose()); - - StepVerifier.create(rSocket) - .expectErrorMatches( - err -> - err instanceof IllegalArgumentException - && "smallest allowed mtu size is 64 bytes, provided: 2" - .equals(err.getMessage())) - .verify(Duration.ofSeconds(5)); - } - - @ParameterizedTest - @MethodSource("serverTransportProvider") + @MethodSource("arguments") void clientSucceedsWithEnabledFragmentationOnSufficientMtu( ServerTransport serverTransport) { - CloseableChannel server = createServer(serverTransport, f -> f.fragment(100)).block(); + CloseableChannel server = + RSocketServer.create(mockAcceptor()).fragment(100).bind(serverTransport).block(); Mono rSocket = - createClient(TcpClientTransport.create(server.address()), f -> f.fragment(100)) + RSocketConnector.create() + .fragment(100) + .connect(TcpClientTransport.create(server.address())) .doFinally(s -> server.dispose()); StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } @ParameterizedTest - @MethodSource("serverTransportProvider") - void clientSucceedsWithDisabledFragmentation() { - CloseableChannel server = - createServer(TcpServerTransport.create("localhost", 0), Function.identity()).block(); + @MethodSource("arguments") + void clientSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + CloseableChannel server = RSocketServer.create(mockAcceptor()).bind(serverTransport).block(); Mono rSocket = - createClient(TcpClientTransport.create(server.address()), Function.identity()) + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) .doFinally(s -> server.dispose()); StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } - private Mono createClient( - ClientTransport transport, Function f) { - return f.apply(RSocketFactory.connect()).transport(transport).start(); - } - - private Mono createServer( - ServerTransport transport, - Function f) { - return f.apply(receive()).acceptor(mockAcceptor()).transport(transport).start(); - } - private SocketAcceptor mockAcceptor() { SocketAcceptor mock = Mockito.mock(SocketAcceptor.class); Mockito.when(mock.accept(Mockito.any(), Mockito.any())) .thenReturn(Mono.just(Mockito.mock(RSocket.class))); return mock; } - - static Stream> serverTransportProvider() { - String host = "localhost"; - int port = 0; - return Stream.of( - TcpServerTransport.create(host, port), WebsocketServerTransport.create(host, port)); - } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java index f32d28a0b..6fd3de791 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java @@ -2,8 +2,9 @@ import io.rsocket.ConnectionSetupPayload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; @@ -41,19 +42,15 @@ void rejectSetupTcp( Mono serverRequester = acceptor.requesterRSocket(); CloseableChannel channel = - RSocketFactory.receive() - .acceptor(acceptor) - .transport(serverTransport.apply(new InetSocketAddress("localhost", 0))) - .start() + RSocketServer.create(acceptor) + .bind(serverTransport.apply(new InetSocketAddress("localhost", 0))) .block(Duration.ofSeconds(5)); ErrorConsumer errorConsumer = new ErrorConsumer(); RSocket clientRequester = - RSocketFactory.connect() - .errorConsumer(errorConsumer) - .transport(clientTransport.apply(channel.address())) - .start() + RSocketConnector.connectWith(clientTransport.apply(channel.address())) + .doOnError(errorConsumer) .block(Duration.ofSeconds(5)); StepVerifier.create(errorConsumer.errors().next()) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java index c2e136635..88c64648c 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java @@ -17,7 +17,8 @@ package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PerfTest; import io.rsocket.test.PingClient; @@ -81,16 +82,15 @@ private static PingClient newResumablePingClient() { } private static PingClient newPingClient(boolean isResumable) { - RSocketFactory.ClientRSocketFactory clientRSocketFactory = RSocketFactory.connect(); + RSocketConnector connector = RSocketConnector.create(); if (isResumable) { - clientRSocketFactory.resume(); + connector.resume(new Resume()); } Mono rSocket = - clientRSocketFactory - .frameDecoder(PayloadDecoder.ZERO_COPY) - .keepAlive(Duration.ofMinutes(1), Duration.ofMinutes(30), 3) - .transport(TcpClientTransport.create(port)) - .start(); + connector + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMinutes(1), Duration.ofMinutes(30)) + .connect(TcpClientTransport.create(port)); return new PingClient(rSocket); } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java index b40f35e51..338868470 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java @@ -16,7 +16,8 @@ package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.TcpServerTransport; @@ -31,15 +32,13 @@ public static void main(String... args) { System.out.println("port: " + port); System.out.println("resume enabled: " + isResume); - RSocketFactory.ServerRSocketFactory serverRSocketFactory = RSocketFactory.receive(); + RSocketServer server = RSocketServer.create(new PingHandler()); if (isResume) { - serverRSocketFactory.resume(); + server.resume(new Resume()); } - serverRSocketFactory - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(TcpServerTransport.create("localhost", port)) - .start() + server + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create("localhost", port)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java new file mode 100644 index 000000000..b77de6d4e --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -0,0 +1,55 @@ +package io.rsocket.transport.netty; + +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import reactor.core.Exceptions; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +public class TcpSecureTransportTest implements TransportTest { + private final TransportPair transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server) -> + TcpClientTransport.create( + TcpClient.create() + .addressSupplier(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE)))), + address -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + TcpServer server = + TcpServer.create() + .addressSupplier(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return TcpServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(10); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java deleted file mode 100644 index 25b443dd6..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; - -final class TcpUriHandlerTest implements UriHandlerTest { - - @Override - public String getInvalidUri() { - return "http://test"; - } - - @Override - public UriHandler getUriHandler() { - return new TcpUriHandler(); - } - - @Override - public String getValidUri() { - return "tcp://test:9898"; - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java deleted file mode 100644 index a71cc27f9..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class TcpUriTransportRegistryTest { - - @DisplayName("non-tcp URI does not return TcpClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("amqp://localhost")) - .isNotInstanceOf(TcpClientTransport.class) - .isNotInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("tcp URI returns TcpClientTransport") - @Test - void clientForUriTcp() { - assertThat(UriTransportRegistry.clientForUri("tcp://test:9898")) - .isInstanceOf(TcpClientTransport.class); - } - - @DisplayName("non-tcp URI does not return TcpServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("amqp://localhost")) - .isNotInstanceOf(TcpServerTransport.class) - .isNotInstanceOf(WebsocketServerTransport.class); - } - - @DisplayName("tcp URI returns TcpServerTransport") - @Test - void serverForUriTcp() { - assertThat(UriTransportRegistry.serverForUri("tcp://test:9898")) - .isInstanceOf(TcpServerTransport.class); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java index 4fe40d232..7028a3846 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java @@ -3,7 +3,8 @@ import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.transport.netty.server.WebsocketRouteTransport; @@ -23,8 +24,7 @@ public class WebSocketTransportIntegrationTest { @Test public void sendStreamOfDataWithExternalHttpServerTest() { ServerTransport.ConnectionAcceptor acceptor = - RSocketFactory.receive() - .acceptor( + RSocketServer.create( (setupPayload, sendingRSocket) -> { return Mono.just( new AbstractRSocket() { @@ -35,7 +35,7 @@ public Flux requestStream(Payload payload) { } }); }) - .toConnectionAcceptor(); + .asConnectionAcceptor(); DisposableServer server = HttpServer.create() @@ -44,11 +44,9 @@ public Flux requestStream(Payload payload) { .bindNow(); RSocket rsocket = - RSocketFactory.connect() - .transport( + RSocketConnector.connectWith( WebsocketClientTransport.create( URI.create("ws://" + server.host() + ":" + server.port() + "/test"))) - .start() .block(); StepVerifier.create(rsocket.requestStream(EmptyPayload.INSTANCE)) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java index 306be4e43..a784a43c0 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.transport.netty.client.WebsocketClientTransport; @@ -29,10 +29,9 @@ public final class WebsocketPing { public static void main(String... args) { Mono client = - RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(WebsocketClientTransport.create(7878)) - .start(); + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(WebsocketClientTransport.create(7878)); PingClient pingClient = new PingClient(client); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java index eac091dd8..ab6c343de 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -8,7 +8,12 @@ import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.util.ReferenceCountUtil; -import io.rsocket.*; +import io.rsocket.AbstractRSocket; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.transport.netty.server.WebsocketRouteTransport; @@ -42,10 +47,8 @@ void tearDown() { @MethodSource("provideServerTransport") void webSocketPingPong(ServerTransport serverTransport) { server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new EchoRSocket())) - .transport(serverTransport) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new EchoRSocket())) + .bind(serverTransport) .block(); String expectedData = "data"; @@ -63,10 +66,7 @@ void webSocketPingPong(ServerTransport serverTransport) { .port(port)); RSocket rSocket = - RSocketFactory.connect() - .transport(WebsocketClientTransport.create(httpClient, "/")) - .start() - .block(); + RSocketConnector.connectWith(WebsocketClientTransport.create(httpClient, "/")).block(); rSocket .requestResponse(DefaultPayload.create(expectedData)) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java index 7fdb1813a..84dc816be 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.WebsocketServerTransport; @@ -24,11 +24,9 @@ public final class WebsocketPongServer { public static void main(String... args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(WebsocketServerTransport.create(7878)) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(WebsocketServerTransport.create(7878)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java deleted file mode 100644 index 72a700b0e..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; - -final class WebsocketUriHandlerTest implements UriHandlerTest { - - @Override - public String getInvalidUri() { - return "amqp://test"; - } - - @Override - public UriHandler getUriHandler() { - return new WebsocketUriHandler(); - } - - @Override - public String getValidUri() { - return "ws://test:9898"; - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java deleted file mode 100644 index 5688f14ed..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class WebsocketUriTransportRegistryTest { - - @DisplayName("non-ws URI does not return WebsocketClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("amqp://localhost")) - .isNotInstanceOf(TcpClientTransport.class) - .isNotInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("ws URI returns WebsocketClientTransport") - @Test - void clientForUriWebsocket() { - assertThat(UriTransportRegistry.clientForUri("ws://test:9898")) - .isInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("non-ws URI does not return WebsocketServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("amqp://localhost")) - .isNotInstanceOf(TcpServerTransport.class) - .isNotInstanceOf(WebsocketServerTransport.class); - } - - @DisplayName("ws URI returns WebsocketServerTransport") - @Test - void serverForUriWebsocket() { - assertThat(UriTransportRegistry.serverForUri("ws://test:9898")) - .isInstanceOf(WebsocketServerTransport.class); - } -}