diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 21112c24d..fefb06003 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -54,7 +54,7 @@ import io.rsocket.util.MonoLifecycleHandler; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CancellationException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Consumer; import java.util.function.LongConsumer; @@ -260,6 +260,7 @@ public void doOnTerminal( removeStreamReceiver(streamId); } }); + receivers.put(streamId, receiver); return receiver.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); @@ -281,7 +282,7 @@ private Flux handleRequestStream(final Payload payload) { final UnboundedProcessor sendProcessor = this.sendProcessor; final UnicastProcessor receiver = UnicastProcessor.create(); - final AtomicBoolean payloadReleasedFlag = new AtomicBoolean(false); + final AtomicInteger wip = new AtomicInteger(0); receivers.put(streamId, receiver); @@ -293,30 +294,49 @@ private Flux handleRequestStream(final Payload payload) { @Override public void accept(long n) { - if (firstRequest && !receiver.isDisposed()) { + if (firstRequest) { firstRequest = false; - if (!payloadReleasedFlag.getAndSet(true)) { - sendProcessor.onNext( - RequestStreamFrameFlyweight.encodeReleasingPayload( - allocator, streamId, n, payload)); + if (wip.getAndIncrement() != 0) { + // no need to do anything. + // stream was canceled and fist payload has already been discarded + return; } - } else if (contains(streamId) && !receiver.isDisposed()) { + int missed = 1; + boolean firstHasBeenSent = false; + for (; ; ) { + if (!firstHasBeenSent) { + sendProcessor.onNext( + RequestStreamFrameFlyweight.encodeReleasingPayload( + allocator, streamId, n, payload)); + firstHasBeenSent = true; + } else { + // if first frame was sent but we cycling again, it means that wip was + // incremented at doOnCancel + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + return; + } + + missed = wip.addAndGet(-missed); + if (missed == 0) { + return; + } + } + } else if (!receiver.isDisposed()) { sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); } } }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t)); - } - }) .doOnCancel( () -> { - if (!payloadReleasedFlag.getAndSet(true)) { - payload.release(); + if (wip.getAndIncrement() != 0) { + return; } - if (contains(streamId) && !receiver.isDisposed()) { + + // check if we need to release payload + // only applicable if the cancel appears earlier than actual request + if (payload.refCnt() > 0) { + payload.release(); + } else { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } }) @@ -330,30 +350,32 @@ private Flux handleChannel(Flux request) { return Flux.error(err); } - return request.switchOnFirst( - (s, flux) -> { - Payload payload = s.get(); - if (payload != null) { - if (!PayloadValidationUtils.isValid(mtu, payload)) { - payload.release(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - errorConsumer.accept(t); - return Mono.error(t); - } - return handleChannel(payload, flux); - } else { - return flux; - } - }, - false); + return request + .switchOnFirst( + (s, flux) -> { + Payload payload = s.get(); + if (payload != null) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + errorConsumer.accept(t); + return Mono.error(t); + } + return handleChannel(payload, flux); + } else { + return flux; + } + }, + false) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Flux handleChannel(Payload initialPayload, Flux inboundFlux) { final UnboundedProcessor sendProcessor = this.sendProcessor; - final AtomicBoolean payloadReleasedFlag = new AtomicBoolean(false); final int streamId = streamIdSupplier.nextStreamId(receivers); + final AtomicInteger wip = new AtomicInteger(0); final UnicastProcessor receiver = UnicastProcessor.create(); final BaseSubscriber upstreamSubscriber = new BaseSubscriber() { @@ -421,19 +443,47 @@ protected void hookFinally(SignalType type) { public void accept(long n) { if (firstRequest) { firstRequest = false; - senders.put(streamId, upstreamSubscriber); - receivers.put(streamId, receiver); - - inboundFlux - .limitRate(Queues.SMALL_BUFFER_SIZE) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) - .subscribe(upstreamSubscriber); - if (!payloadReleasedFlag.getAndSet(true)) { - ByteBuf frame = - RequestChannelFrameFlyweight.encodeReleasingPayload( - allocator, streamId, false, n, initialPayload); - - sendProcessor.onNext(frame); + if (wip.getAndIncrement() != 0) { + // no need to do anything. + // stream was canceled and fist payload has already been discarded + return; + } + int missed = 1; + boolean firstHasBeenSent = false; + for (; ; ) { + if (!firstHasBeenSent) { + ByteBuf frame; + try { + frame = + RequestChannelFrameFlyweight.encodeReleasingPayload( + allocator, streamId, false, n, initialPayload); + } catch (IllegalReferenceCountException e) { + return; + } + + senders.put(streamId, upstreamSubscriber); + receivers.put(streamId, receiver); + + inboundFlux + .limitRate(Queues.SMALL_BUFFER_SIZE) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(upstreamSubscriber); + + sendProcessor.onNext(frame); + firstHasBeenSent = true; + } else { + // if first frame was sent but we cycling again, it means that wip was + // incremented at doOnCancel + senders.remove(streamId, upstreamSubscriber); + receivers.remove(streamId, receiver); + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + return; + } + + missed = wip.addAndGet(-missed); + if (missed == 0) { + return; + } } } else { sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); @@ -442,22 +492,22 @@ public void accept(long n) { }) .doOnError( t -> { - if (receivers.remove(streamId, receiver)) { - upstreamSubscriber.cancel(); - } + upstreamSubscriber.cancel(); + receivers.remove(streamId, receiver); }) .doOnComplete(() -> receivers.remove(streamId, receiver)) .doOnCancel( () -> { - if (!payloadReleasedFlag.getAndSet(true)) { - initialPayload.release(); + upstreamSubscriber.cancel(); + if (wip.getAndIncrement() != 0) { + return; } + + // need to send frame only if RequestChannelFrame was sent if (receivers.remove(streamId, receiver)) { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - upstreamSubscriber.cancel(); } - }) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + }); } private Mono handleMetadataPush(Payload payload) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java index ea54aa374..73bfd38f1 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java @@ -37,8 +37,34 @@ static ByteBuf encode( boolean hasMetadata, ByteBuf data) { - final boolean addData = data != null && data.isReadable(); - final boolean addMetadata = hasMetadata && metadata.isReadable(); + final boolean addData; + if (data != null) { + if (data.isReadable()) { + addData = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + data.release(); + addData = false; + } + } else { + addData = false; + } + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } if (hasMetadata) { int length = metadata.readableBytes(); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java index 039c72886..32f086a15 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java @@ -14,7 +14,6 @@ public static ByteBuf encode( @Nullable final ByteBuf metadata) { final boolean hasMetadata = metadata != null; - final boolean addMetadata = hasMetadata && metadata.isReadable(); int flags = 0; @@ -27,6 +26,21 @@ public static ByteBuf encode( .writeInt(ttl) .writeInt(numRequests); + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + if (addMetadata) { return allocator.compositeBuffer(2).addComponents(true, header, metadata); } else { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java index e3a9a47ba..a39acef92 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java @@ -2,13 +2,21 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class MetadataPushFrameFlyweight { public static ByteBuf encodeReleasingPayload(ByteBufAllocator allocator, Payload payload) { final ByteBuf metadata = payload.metadata().retain(); - payload.release(); + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + metadata.release(); + throw e; + } return encode(allocator, metadata); } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java index 4c2ebdf6e..53ac6150b 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class PayloadFrameFlyweight { @@ -23,11 +24,31 @@ public static ByteBuf encodeNextCompleteReleasingPayload( static ByteBuf encodeReleasingPayload( ByteBufAllocator allocator, int streamId, boolean complete, Payload payload) { - final boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; - final ByteBuf data = payload.data().retain(); - - payload.release(); + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } return encode(allocator, streamId, false, complete, true, metadata, data); } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java index 7c3cbb574..c0db21170 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestChannelFrameFlyweight { @@ -17,11 +18,31 @@ public static ByteBuf encodeReleasingPayload( long initialRequestN, Payload payload) { - final boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; - final ByteBuf data = payload.data().retain(); - - payload.release(); + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } return encode(allocator, streamId, false, complete, initialRequestN, metadata, data); } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java index 287f765f7..e091edcc3 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestFireAndForgetFrameFlyweight { @@ -13,11 +14,31 @@ private RequestFireAndForgetFrameFlyweight() {} public static ByteBuf encodeReleasingPayload( ByteBufAllocator allocator, int streamId, Payload payload) { - final boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; - final ByteBuf data = payload.data().retain(); - - payload.release(); + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } return FLYWEIGHT.encode(allocator, streamId, false, metadata, data); } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java index 3fbac27d2..782c70965 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestResponseFrameFlyweight { @@ -13,11 +14,31 @@ private RequestResponseFrameFlyweight() {} public static ByteBuf encodeReleasingPayload( ByteBufAllocator allocator, int streamId, Payload payload) { - final boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; - final ByteBuf data = payload.data().retain(); - - payload.release(); + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } return encode(allocator, streamId, false, metadata, data); } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java index ff1435652..2fb209ffb 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.rsocket.Payload; public class RequestStreamFrameFlyweight { @@ -13,11 +14,31 @@ private RequestStreamFrameFlyweight() {} public static ByteBuf encodeReleasingPayload( ByteBufAllocator allocator, int streamId, long initialRequestN, Payload payload) { - final boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; - final ByteBuf data = payload.data().retain(); - - payload.release(); + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } return encode(allocator, streamId, false, initialRequestN, metadata, data); } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java index af4c8768b..c5b06e086 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java @@ -99,6 +99,7 @@ public static UnicastMonoProcessor create(MonoLifecycleHandler lifecyc UnicastMonoProcessor.class, Subscription.class, "subscription"); CoreSubscriber actual; + boolean hasDownstream = false; Throwable error; O value; @@ -185,7 +186,7 @@ private void complete(O v) { if (state == HAS_REQUEST_NO_RESULT) { if (STATE.compareAndSet(this, HAS_REQUEST_NO_RESULT, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; - actual = null; + hasDownstream = false; value = null; lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, v, null); a.onNext(v); @@ -222,7 +223,7 @@ private void complete() { if (state == HAS_REQUEST_NO_RESULT || state == NO_REQUEST_NO_RESULT) { if (STATE.compareAndSet(this, state, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; - actual = null; + hasDownstream = false; lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, null, null); a.onComplete(); return; @@ -256,7 +257,7 @@ private void complete(Throwable e) { if (state == HAS_REQUEST_NO_RESULT || state == NO_REQUEST_NO_RESULT) { if (STATE.compareAndSet(this, state, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; - actual = null; + hasDownstream = false; lifecycleHandler.doOnTerminal(SignalType.ON_ERROR, null, e); a.onError(e); return; @@ -278,6 +279,7 @@ public void subscribe(CoreSubscriber actual) { lh.doOnSubscribe(); + this.hasDownstream = true; this.actual = actual; int state = this.state; @@ -303,7 +305,7 @@ public void subscribe(CoreSubscriber actual) { // no value // e.g. [onError / onComplete / dispose] only if (state == NO_REQUEST_HAS_RESULT && this.value == null) { - this.actual = null; + this.hasDownstream = false; Throwable e = this.error; // barrier to flush changes STATE.set(this, HAS_REQUEST_HAS_RESULT); @@ -340,7 +342,7 @@ public final void request(long n) { if (STATE.compareAndSet(this, NO_REQUEST_HAS_RESULT, HAS_REQUEST_HAS_RESULT)) { final Subscriber a = actual; final O v = value; - actual = null; + hasDownstream = false; value = null; lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, v, null); a.onNext(v); @@ -360,7 +362,7 @@ public final void cancel() { if (STATE.getAndSet(this, CANCELLED) <= HAS_REQUEST_NO_RESULT) { Operators.onDiscard(value, currentContext()); value = null; - actual = null; + hasDownstream = false; lifecycleHandler.doOnTerminal(SignalType.CANCEL, null, null); final Subscription s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription()); if (s != null && s != Operators.cancelledSubscription()) { @@ -502,6 +504,6 @@ public Object scanUnsafe(Attr key) { * @return true if any {@link Subscriber} is actively subscribed */ public final boolean hasDownstream() { - return state > NO_SUBSCRIBER_HAS_RESULT && actual != null; + return state > NO_SUBSCRIBER_HAS_RESULT && hasDownstream; } } diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java index 2044779ef..800e5d678 100644 --- a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -3,7 +3,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; -import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import org.assertj.core.api.Assertions; @@ -35,22 +34,9 @@ public LeaksTrackingByteBufAllocator assertHasNoLeaks() { try { Assertions.assertThat(tracker) .allSatisfy( - buf -> { - if (buf instanceof CompositeByteBuf) { - if (buf.refCnt() > 0) { - List decomposed = - ((CompositeByteBuf) buf).decompose(0, buf.readableBytes()); - for (int i = 0; i < decomposed.size(); i++) { - Assertions.assertThat(decomposed.get(i)) - .matches(bb -> bb.refCnt() == 0, "Got unreleased CompositeByteBuf"); - } - } - - } else { + buf -> Assertions.assertThat(buf) - .matches(bb -> bb.refCnt() == 0, "buffer should be released"); - } - }); + .matches(bb -> bb.refCnt() == 0, "buffer should be released")); } finally { tracker.clear(); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index e536d2db4..3b62bc437 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -38,6 +38,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -71,6 +72,7 @@ import java.util.function.Function; import java.util.stream.Stream; import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -84,6 +86,7 @@ import org.reactivestreams.Subscription; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; @@ -97,6 +100,8 @@ public class RSocketRequesterTest { @BeforeEach public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); rule = new ClientSocketRule(); rule.apply( new Statement() { @@ -107,6 +112,12 @@ public void evaluate() {} .evaluate(); } + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + } + @Test @Timeout(2_000) public void testInvalidFrameOnStream0() { @@ -403,21 +414,8 @@ static Stream>> prepareCalls() { rule.assertHasNoLeaks(); } - @Test - @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") - @SuppressWarnings("unchecked") - public void checkNoLeaksOnRacingTest() { - - racingCases() - .forEach( - a -> { - ((Runnable) a.get()[0]).run(); - checkNoLeaksOnRacing( - (Function>) a.get()[1], - (BiConsumer, ClientSocketRule>) a.get()[2]); - }); - } - + @ParameterizedTest + @MethodSource("racingCases") public void checkNoLeaksOnRacing( Function> initiator, BiConsumer, ClientSocketRule> runner) { @@ -437,7 +435,7 @@ public void evaluate() {} } Publisher payloadP = initiator.apply(clientSocketRule); - AssertSubscriber assertSubscriber = AssertSubscriber.create(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); if (payloadP instanceof Flux) { ((Flux) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); @@ -450,14 +448,13 @@ public void evaluate() {} Assertions.assertThat(clientSocketRule.connection.getSent()) .allMatch(ReferenceCounted::release); - rule.assertHasNoLeaks(); + clientSocketRule.assertHasNoLeaks(); } } private static Stream racingCases() { return Stream.of( Arguments.of( - (Runnable) () -> System.out.println("RequestStream downstream cancellation case"), (Function>) (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE), (BiConsumer, ClientSocketRule>) @@ -467,6 +464,7 @@ private static Stream racingCases() { metadata.writeCharSequence("abc", CharsetUtil.UTF_8); ByteBuf data = allocator.buffer(); data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); ByteBuf frame = PayloadFrameFlyweight.encode( @@ -475,7 +473,6 @@ private static Stream racingCases() { RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); }), Arguments.of( - (Runnable) () -> System.out.println("RequestChannel downstream cancellation case"), (Function>) (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), (BiConsumer, ClientSocketRule>) @@ -485,6 +482,7 @@ private static Stream racingCases() { metadata.writeCharSequence("abc", CharsetUtil.UTF_8); ByteBuf data = allocator.buffer(); data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); ByteBuf frame = PayloadFrameFlyweight.encode( @@ -493,79 +491,143 @@ private static Stream racingCases() { RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); }), Arguments.of( - (Runnable) () -> System.out.println("RequestChannel upstream cancellation 1"), (Function>) (rule) -> { ByteBufAllocator allocator = rule.alloc(); ByteBuf metadata = allocator.buffer(); - metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); ByteBuf data = allocator.buffer(); - data.writeCharSequence("def", CharsetUtil.UTF_8); - return rule.socket.requestChannel( - Flux.just(ByteBufPayload.create(data, metadata))); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + + return rule.socket.requestStream(payload); }, (BiConsumer, ClientSocketRule>) (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_STREAM, + "Expected first frame matches {" + + REQUEST_STREAM + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> { ByteBufAllocator allocator = rule.alloc(); - int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); - ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); - - RaceTestUtils.race( - () -> as.request(1), () -> rule.connection.addToReceivedBuffer(frame)); + return rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + sink.complete(); + return ++index; + })); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_CHANNEL, + "Expected first frame matches {" + + REQUEST_CHANNEL + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } }), Arguments.of( - (Runnable) () -> System.out.println("RequestChannel upstream cancellation 2"), (Function>) (rule) -> rule.socket.requestChannel( Flux.generate( () -> 1L, (index, sink) -> { - final Payload payload = - ByteBufPayload.create("d" + index, "m" + index); + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); sink.next(payload); return ++index; })), (BiConsumer, ClientSocketRule>) (as, rule) -> { ByteBufAllocator allocator = rule.alloc(); + as.request(1); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); - as.request(1); - RaceTestUtils.race( () -> as.request(Long.MAX_VALUE), () -> rule.connection.addToReceivedBuffer(frame)); }), Arguments.of( - (Runnable) () -> System.out.println("RequestChannel remote error"), (Function>) (rule) -> rule.socket.requestChannel( Flux.generate( () -> 1L, (index, sink) -> { - final Payload payload = - ByteBufPayload.create("d" + index, "m" + index); + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); sink.next(payload); return ++index; })), (BiConsumer, ClientSocketRule>) (as, rule) -> { ByteBufAllocator allocator = rule.alloc(); + as.request(1); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); ByteBuf frame = ErrorFrameFlyweight.encode(allocator, streamId, new RuntimeException("test")); - as.request(1); - RaceTestUtils.race( () -> as.request(Long.MAX_VALUE), () -> rule.connection.addToReceivedBuffer(frame)); }), Arguments.of( - (Runnable) () -> System.out.println("RequestResponse downstream cancellation"), (Function>) (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), (BiConsumer, ClientSocketRule>) @@ -575,6 +637,7 @@ private static Stream racingCases() { metadata.writeCharSequence("abc", CharsetUtil.UTF_8); ByteBuf data = allocator.buffer(); data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); ByteBuf frame = PayloadFrameFlyweight.encode( diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 48910b3a2..c19456548 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -94,6 +94,8 @@ public class RSocketResponderTest { @BeforeEach public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped(t -> {}); rule = new ServerSocketRule(); rule.apply( new Statement() { @@ -107,6 +109,7 @@ public void evaluate() {} @AfterEach public void tearDown() { Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); } @Test @@ -247,9 +250,7 @@ protected void hookOnSubscribe(Subscription subscription) { } @Test - @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { - ByteBufAllocator allocator = rule.alloc(); for (int i = 0; i < 10000; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); @@ -258,33 +259,32 @@ public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { new AbstractRSocket() { @Override public Flux requestChannel(Publisher payloads) { - ((Flux) payloads) - .doOnNext(ReferenceCountUtil::safeRelease) - .subscribe(assertSubscriber); + payloads.subscribe(assertSubscriber); return Flux.never(); } }, Integer.MAX_VALUE); rule.sendRequest(1, REQUEST_CHANNEL); + ByteBuf metadata1 = allocator.buffer(); - metadata1.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); ByteBuf data1 = allocator.buffer(); - data1.writeCharSequence("def", CharsetUtil.UTF_8); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); ByteBuf nextFrame1 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); ByteBuf metadata2 = allocator.buffer(); - metadata2.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); ByteBuf data2 = allocator.buffer(); - data2.writeCharSequence("def", CharsetUtil.UTF_8); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); ByteBuf nextFrame2 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); ByteBuf metadata3 = allocator.buffer(); - metadata3.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); ByteBuf data3 = allocator.buffer(); - data3.writeCharSequence("def", CharsetUtil.UTF_8); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); ByteBuf nextFrame3 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); @@ -294,6 +294,8 @@ public Flux requestChannel(Publisher payloads) { }, assertSubscriber::cancel); + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); @@ -301,7 +303,6 @@ public Flux requestChannel(Publisher payloads) { } @Test - @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); @@ -341,7 +342,6 @@ public Flux requestChannel(Publisher payloads) { } @Test - @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); @@ -388,27 +388,25 @@ public Flux requestChannel(Publisher payloads) { } @Test - @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") public void - checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() - throws InterruptedException { + checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); for (int i = 0; i < 10000; i++) { FluxSink[] sinks = new FluxSink[1]; - + AssertSubscriber assertSubscriber = AssertSubscriber.create(); rule.setAcceptingSocket( new AbstractRSocket() { @Override public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); return Flux.create( - sink -> { - sinks[0] = sink; - }, - FluxSink.OverflowStrategy.IGNORE) - .mergeWith(payloads); + sink -> { + sinks[0] = sink; + }, + FluxSink.OverflowStrategy.IGNORE); } }, 1); @@ -416,23 +414,23 @@ public Flux requestChannel(Publisher payloads) { rule.sendRequest(1, REQUEST_CHANNEL); ByteBuf metadata1 = allocator.buffer(); - metadata1.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); ByteBuf data1 = allocator.buffer(); - data1.writeCharSequence("def", CharsetUtil.UTF_8); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); ByteBuf nextFrame1 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); ByteBuf metadata2 = allocator.buffer(); - metadata2.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); ByteBuf data2 = allocator.buffer(); - data2.writeCharSequence("def", CharsetUtil.UTF_8); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); ByteBuf nextFrame2 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); ByteBuf metadata3 = allocator.buffer(); - metadata3.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); ByteBuf data3 = allocator.buffer(); - data3.writeCharSequence("def", CharsetUtil.UTF_8); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); ByteBuf nextFrame3 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); @@ -454,13 +452,12 @@ public Flux requestChannel(Publisher payloads) { parallel); Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); - + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } } @Test - @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); @@ -585,23 +582,23 @@ public Flux requestChannel(Publisher payloads) { ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); ByteBuf metadata1 = allocator.buffer(); - metadata1.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); ByteBuf data1 = allocator.buffer(); - data1.writeCharSequence("def", CharsetUtil.UTF_8); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); ByteBuf nextFrame1 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); ByteBuf metadata2 = allocator.buffer(); - metadata2.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); ByteBuf data2 = allocator.buffer(); - data2.writeCharSequence("def", CharsetUtil.UTF_8); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); ByteBuf nextFrame2 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); ByteBuf metadata3 = allocator.buffer(); - metadata3.writeCharSequence("abc", CharsetUtil.UTF_8); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); ByteBuf data3 = allocator.buffer(); - data3.writeCharSequence("def", CharsetUtil.UTF_8); + data3.writeCharSequence("de3", CharsetUtil.UTF_8); ByteBuf nextFrame3 = PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index 5e94935c5..63300c718 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -26,7 +26,13 @@ public final class ByteBufRepresentation extends StandardRepresentation { protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { try { - return ByteBufUtil.prettyHexDump((ByteBuf) object); + String normalBufferString = object.toString(); + String prettyHexDump = ByteBufUtil.prettyHexDump((ByteBuf) object); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); } catch (IllegalReferenceCountException e) { // noops }