diff --git a/gradle.properties b/gradle.properties index 9008d313b..cd6123e21 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,4 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -version=0.11.13.BUILD-SNAPSHOT \ No newline at end of file + +version=0.11.12.BUILD-SNAPSHOT diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java index 6b4626f9b..b8ec5b863 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java @@ -16,7 +16,6 @@ package io.rsocket.internal; -import io.netty.util.ReferenceCountUtil; import java.util.Objects; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -24,10 +23,12 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; import reactor.util.annotation.Nullable; +import reactor.util.context.Context; public final class SwitchTransformFlux extends Flux { @@ -46,30 +47,53 @@ public int getPrefetch() { } @Override + @SuppressWarnings("unchecked") public void subscribe(CoreSubscriber actual) { - source.subscribe(new SwitchTransformMain<>(actual, transformer)); + if (actual instanceof Fuseable.ConditionalSubscriber) { + source.subscribe( + new SwitchTransformConditionalOperator<>( + (Fuseable.ConditionalSubscriber) actual, transformer)); + return; + } + source.subscribe(new SwitchTransformOperator<>(actual, transformer)); } - static final class SwitchTransformMain implements CoreSubscriber, Scannable { + static final class SwitchTransformOperator extends Flux + implements CoreSubscriber, Subscription, Scannable { - final CoreSubscriber actual; + final CoreSubscriber outer; final BiFunction, Publisher> transformer; - final SwitchTransformInner inner; Subscription s; + Throwable throwable; + + volatile boolean done; + volatile T first; + + volatile CoreSubscriber inner; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater INNER = + AtomicReferenceFieldUpdater.newUpdater( + SwitchTransformOperator.class, CoreSubscriber.class, "inner"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(SwitchTransformOperator.class, "wip"); volatile int once; @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformMain.class, "once"); + static final AtomicIntegerFieldUpdater ONCE = + AtomicIntegerFieldUpdater.newUpdater(SwitchTransformOperator.class, "once"); - SwitchTransformMain( - CoreSubscriber actual, + SwitchTransformOperator( + CoreSubscriber outer, BiFunction, Publisher> transformer) { - this.actual = actual; + this.outer = outer; this.transformer = transformer; - this.inner = new SwitchTransformInner<>(this); } @Override @@ -81,6 +105,48 @@ public Object scanUnsafe(Attr key) { return null; } + @Override + public Context currentContext() { + CoreSubscriber actual = inner; + + if (actual != null) { + return actual.currentContext(); + } + + return outer.currentContext(); + } + + @Override + public void cancel() { + if (s != Operators.cancelledSubscription()) { + Subscription s = this.s; + this.s = Operators.cancelledSubscription(); + + if (WIP.getAndIncrement(this) == 0) { + INNER.lazySet(this, null); + + T f = first; + if (f != null) { + first = null; + Operators.onDiscard(f, currentContext()); + } + } + + s.cancel(); + } + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + INNER.lazySet(this, actual); + actual.onSubscribe(this); + } else { + Operators.error( + actual, new IllegalStateException("SwitchTransform allows only one Subscriber")); + } + } + @Override public void onSubscribe(Subscription s) { if (Operators.validate(this.s, s)) { @@ -91,161 +157,423 @@ public void onSubscribe(Subscription s) { @Override public void onNext(T t) { - if (isCanceled()) { + if (done) { + Operators.onNextDropped(t, currentContext()); return; } - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + CoreSubscriber i = inner; + + if (i == null) { try { - inner.first = t; + first = t; Publisher result = Objects.requireNonNull( - transformer.apply(t, inner), "The transformer returned a null value"); - result.subscribe(actual); + transformer.apply(t, this), "The transformer returned a null value"); + result.subscribe(outer); return; } catch (Throwable e) { - onError(Operators.onOperatorError(s, e, t, actual.currentContext())); - ReferenceCountUtil.safeRelease(t); + onError(Operators.onOperatorError(s, e, t, currentContext())); return; } } - inner.onNext(t); + i.onNext(t); } @Override public void onError(Throwable t) { - if (isCanceled()) { + if (done) { + Operators.onErrorDropped(t, currentContext()); return; } - if (once != 0) { - inner.onError(t); + throwable = t; + done = true; + CoreSubscriber i = inner; + + if (i != null) { + if (first == null) { + drainRegular(); + } } else { - actual.onSubscribe(Operators.emptySubscription()); - actual.onError(t); + Operators.error(outer, t); } } @Override public void onComplete() { - if (isCanceled()) { + if (done) { return; } - if (once != 0) { - inner.onComplete(); + done = true; + CoreSubscriber i = inner; + + if (i != null) { + if (first == null) { + drainRegular(); + } } else { - actual.onSubscribe(Operators.emptySubscription()); - actual.onComplete(); + Operators.complete(outer); } } - boolean isCanceled() { - return s == Operators.cancelledSubscription(); + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (first != null && drainRegular() && n != Long.MAX_VALUE) { + if (--n > 0) { + s.request(n); + } + } else { + s.request(n); + } + } } - void cancel() { - s.cancel(); - s = Operators.cancelledSubscription(); + boolean drainRegular() { + if (WIP.getAndIncrement(this) != 0) { + return false; + } + + T f = first; + int m = 1; + boolean sent = false; + Subscription s = this.s; + CoreSubscriber a = inner; + + for (; ; ) { + if (f != null) { + first = null; + + if (s == Operators.cancelledSubscription()) { + Operators.onDiscard(f, a.currentContext()); + return true; + } + + a.onNext(f); + f = null; + sent = true; + } + + if (s == Operators.cancelledSubscription()) { + return sent; + } + + if (done) { + Throwable t = throwable; + if (t != null) { + a.onError(t); + } else { + a.onComplete(); + } + return sent; + } + + m = WIP.addAndGet(this, -m); + + if (m == 0) { + return sent; + } + } } } - static final class SwitchTransformInner extends Flux implements Scannable, Subscription { + static final class SwitchTransformConditionalOperator extends Flux + implements Fuseable.ConditionalSubscriber, Subscription, Scannable { + + final Fuseable.ConditionalSubscriber outer; + final BiFunction, Publisher> transformer; + + Subscription s; + Throwable throwable; - final SwitchTransformMain parent; + volatile boolean done; + volatile T first; - volatile CoreSubscriber actual; + volatile Fuseable.ConditionalSubscriber inner; @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater ACTUAL = - AtomicReferenceFieldUpdater.newUpdater( - SwitchTransformInner.class, CoreSubscriber.class, "actual"); + static final AtomicReferenceFieldUpdater< + SwitchTransformConditionalOperator, Fuseable.ConditionalSubscriber> + INNER = + AtomicReferenceFieldUpdater.newUpdater( + SwitchTransformConditionalOperator.class, + Fuseable.ConditionalSubscriber.class, + "inner"); - volatile V first; + volatile int wip; @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater FIRST = - AtomicReferenceFieldUpdater.newUpdater(SwitchTransformInner.class, Object.class, "first"); + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(SwitchTransformConditionalOperator.class, "wip"); volatile int once; @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformInner.class, "once"); + static final AtomicIntegerFieldUpdater ONCE = + AtomicIntegerFieldUpdater.newUpdater(SwitchTransformConditionalOperator.class, "once"); - SwitchTransformInner(SwitchTransformMain parent) { - this.parent = parent; + SwitchTransformConditionalOperator( + Fuseable.ConditionalSubscriber outer, + BiFunction, Publisher> transformer) { + this.outer = outer; + this.transformer = transformer; } - public void onNext(V t) { - CoreSubscriber a = actual; + @Override + @Nullable + public Object scanUnsafe(Attr key) { + if (key == Attr.CANCELLED) return s == Operators.cancelledSubscription(); + if (key == Attr.PREFETCH) return 1; - if (a != null) { - a.onNext(t); - } + return null; } - public void onError(Throwable t) { - CoreSubscriber a = actual; + @Override + public Context currentContext() { + CoreSubscriber actual = inner; - if (a != null) { - a.onError(t); + if (actual != null) { + return actual.currentContext(); } + + return outer.currentContext(); } - public void onComplete() { - CoreSubscriber a = actual; + @Override + public void cancel() { + if (s != Operators.cancelledSubscription()) { + Subscription s = this.s; + this.s = Operators.cancelledSubscription(); + + if (WIP.getAndIncrement(this) == 0) { + INNER.lazySet(this, null); + + T f = first; + if (f != null) { + first = null; + Operators.onDiscard(f, currentContext()); + } + } - if (a != null) { - a.onComplete(); + s.cancel(); } } @Override - public void subscribe(CoreSubscriber actual) { + @SuppressWarnings("unchecked") + public void subscribe(CoreSubscriber actual) { if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - ACTUAL.lazySet(this, actual); + if (actual instanceof Fuseable.ConditionalSubscriber) { + INNER.lazySet(this, (Fuseable.ConditionalSubscriber) actual); + } else { + INNER.lazySet(this, new ConditionalSubscriberAdapter<>(actual)); + } actual.onSubscribe(this); } else { - actual.onError(new IllegalStateException("SwitchTransform allows only one Subscriber")); + Operators.error( + actual, new IllegalStateException("SwitchTransform allows only one Subscriber")); } } @Override - public void request(long n) { - V f = first; + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(1); + } + } + + @Override + public void onNext(T t) { + if (done) { + Operators.onNextDropped(t, currentContext()); + return; + } + + CoreSubscriber i = inner; + + if (i == null) { + try { + first = t; + Publisher result = + Objects.requireNonNull( + transformer.apply(t, this), "The transformer returned a null value"); + result.subscribe(outer); + return; + } catch (Throwable e) { + onError(Operators.onOperatorError(s, e, t, currentContext())); + return; + } + } - if (f != null && FIRST.compareAndSet(this, f, null)) { - actual.onNext(f); + i.onNext(t); + } + + @Override + public boolean tryOnNext(T t) { + if (done) { + Operators.onNextDropped(t, currentContext()); + return false; + } + + Fuseable.ConditionalSubscriber i = inner; - long r = Operators.addCap(n, -1); - if (r > 0) { - parent.s.request(r); + if (i == null) { + try { + first = t; + Publisher result = + Objects.requireNonNull( + transformer.apply(t, this), "The transformer returned a null value"); + result.subscribe(outer); + return true; + } catch (Throwable e) { + onError(Operators.onOperatorError(s, e, t, currentContext())); + return false; + } + } + + return i.tryOnNext(t); + } + + @Override + public void onError(Throwable t) { + if (done) { + Operators.onErrorDropped(t, currentContext()); + return; + } + + throwable = t; + done = true; + CoreSubscriber i = inner; + + if (i != null) { + if (first == null) { + drainRegular(); } } else { - parent.s.request(n); + Operators.error(outer, t); } } @Override - public void cancel() { - actual = null; - first = null; - parent.cancel(); + public void onComplete() { + if (done) { + return; + } + + done = true; + CoreSubscriber i = inner; + + if (i != null) { + if (first == null) { + drainRegular(); + } + } else { + Operators.complete(outer); + } } @Override - @Nullable - public Object scanUnsafe(Attr key) { - if (key == Attr.PARENT) return parent; - if (key == Attr.ACTUAL) return actual(); + public void request(long n) { + if (Operators.validate(n)) { + if (first != null && drainRegular() && n != Long.MAX_VALUE) { + if (--n > 0) { + s.request(n); + } + } else { + s.request(n); + } + } + } - return null; + boolean drainRegular() { + if (WIP.getAndIncrement(this) != 0) { + return false; + } + + T f = first; + int m = 1; + boolean sent = false; + Subscription s = this.s; + CoreSubscriber a = inner; + + for (; ; ) { + if (f != null) { + first = null; + + if (s == Operators.cancelledSubscription()) { + Operators.onDiscard(f, a.currentContext()); + return true; + } + + a.onNext(f); + f = null; + sent = true; + } + + if (s == Operators.cancelledSubscription()) { + return sent; + } + + if (done) { + Throwable t = throwable; + if (t != null) { + a.onError(t); + } else { + a.onComplete(); + } + return sent; + } + + m = WIP.addAndGet(this, -m); + + if (m == 0) { + return sent; + } + } } + } + + static final class ConditionalSubscriberAdapter implements Fuseable.ConditionalSubscriber { - public CoreSubscriber actual() { - return actual; + final CoreSubscriber delegate; + + ConditionalSubscriberAdapter(CoreSubscriber delegate) { + this.delegate = delegate; + } + + @Override + public Context currentContext() { + return delegate.currentContext(); + } + + @Override + public void onSubscribe(Subscription s) { + delegate.onSubscribe(s); + } + + @Override + public void onNext(T t) { + delegate.onNext(t); + } + + @Override + public void onError(Throwable t) { + delegate.onError(t); + } + + @Override + public void onComplete() { + delegate.onComplete(); + } + + @Override + public boolean tryOnNext(T t) { + delegate.onNext(t); + return true; } } } diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java index a8e60d02b..7e1e68178 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java @@ -11,6 +11,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; + import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Arrays; diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java b/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java index 9159641e3..2297d6bfa 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java @@ -1,20 +1,241 @@ package io.rsocket.internal; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; + import java.time.Duration; +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.Assert; +import org.junit.Assume; import org.junit.Test; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; public class SwitchTransformFluxTest { + @Test + public void shouldBeAbleToCancelSubscription() throws InterruptedException { + for (int j = 0; j < 10; j++) { + ArrayList capturedElements = new ArrayList<>(); + ArrayList capturedCompletions = new ArrayList<>(); + for (int i = 0; i < 1000; i++) { + TestPublisher publisher = TestPublisher.createCold(); + AtomicLong captureElement = new AtomicLong(0L); + AtomicBoolean captureCompletion = new AtomicBoolean(false); + AtomicLong requested = new AtomicLong(); + CountDownLatch latch = new CountDownLatch(1); + Flux switchTransformed = + publisher + .flux() + .doOnRequest(requested::addAndGet) + .doOnCancel(latch::countDown) + .transform( + flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)); + + publisher.next(1L); + + switchTransformed.subscribe( + captureElement::set, + __ -> {}, + () -> captureCompletion.set(true), + s -> + new Thread( + () -> + RaceTestUtils.race( + publisher::complete, + () -> + RaceTestUtils.race( + s::cancel, () -> s.request(1), Schedulers.parallel()), + Schedulers.parallel())) + .start()); + + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + Assert.assertEquals(requested.get(), 1L); + capturedElements.add(captureElement.get()); + capturedCompletions.add(captureCompletion.get()); + } + + Assume.assumeThat(capturedElements, hasItem(equalTo(0L))); + Assume.assumeThat(capturedCompletions, hasItem(equalTo(false))); + } + } + + @Test + public void shouldRequestExpectedAmountOfElements() throws InterruptedException { + TestPublisher publisher = TestPublisher.createCold(); + AtomicLong capture = new AtomicLong(); + AtomicLong requested = new AtomicLong(); + CountDownLatch latch = new CountDownLatch(1); + Flux switchTransformed = + publisher + .flux() + .doOnRequest(requested::addAndGet) + .transform(flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)); + + publisher.next(1L); + + switchTransformed.subscribe( + capture::set, + __ -> {}, + latch::countDown, + s -> { + for (int i = 0; i < 10000; i++) { + RaceTestUtils.race(() -> s.request(1), () -> s.request(1)); + } + RaceTestUtils.race(publisher::complete, publisher::complete); + }); + + latch.await(5, TimeUnit.SECONDS); + + Assert.assertEquals(capture.get(), 1L); + Assert.assertEquals(requested.get(), 20000L); + } + + @Test + public void shouldReturnCorrectContextOnEmptySource() { + Flux switchTransformed = + Flux.empty() + .transform(flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)) + .subscriberContext(Context.of("a", "c")) + .subscriberContext(Context.of("c", "d")); + + StepVerifier.create(switchTransformed, 0) + .expectSubscription() + .thenRequest(1) + .expectAccessibleContext() + .contains("a", "c") + .contains("c", "d") + .then() + .expectComplete() + .verify(); + } + + @Test + public void shouldNotFailOnIncorrectPublisherBehavior() { + TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.CLEANUP_ON_TERMINATE); + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new SwitchTransformFlux<>( + flux, + (first, innerFlux) -> innerFlux.subscriberContext(Context.of("a", "b")))); + + StepVerifier.create( + new Flux() { + @Override + public void subscribe(CoreSubscriber actual) { + switchTransformed.subscribe(actual); + publisher.next(1L); + } + }, + 0) + .thenRequest(1) + .expectNext(1L) + .thenRequest(1) + .then(() -> publisher.next(2L)) + .expectNext(2L) + .then(() -> publisher.error(new RuntimeException())) + .then(() -> publisher.error(new RuntimeException())) + .then(() -> publisher.error(new RuntimeException())) + .then(() -> publisher.error(new RuntimeException())) + .expectError() + .verifyThenAssertThat() + .hasDroppedErrors(3) + .tookLessThan(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + } + + // @Test + // public void shouldNotFailOnIncorrePu + + @Test + public void shouldBeAbleToAccessUpstreamContext() { + TestPublisher publisher = TestPublisher.createCold(); + + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new SwitchTransformFlux<>( + flux, + (first, innerFlux) -> + innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")))) + .subscriberContext(Context.of("a", "c")) + .subscriberContext(Context.of("c", "d")); + + publisher.next(1L); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectNext("1") + .thenRequest(1) + .then(() -> publisher.next(2L)) + .expectNext("2") + .expectAccessibleContext() + .contains("a", "b") + .contains("c", "d") + .then() + .then(publisher::complete) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + } + + @Test + public void shouldNotHangWhenOneElementUpstream() { + TestPublisher publisher = TestPublisher.createCold(); + + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new SwitchTransformFlux<>( + flux, + (first, innerFlux) -> + innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")))) + .subscriberContext(Context.of("a", "c")) + .subscriberContext(Context.of("c", "d")); + + publisher.next(1L); + publisher.complete(); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectNext("1") + .expectComplete() + .verify(Duration.ofSeconds(10)); + + publisher.assertWasRequested(); + publisher.assertNoRequestOverflow(); + } + @Test public void backpressureTest() { TestPublisher publisher = TestPublisher.createCold(); + AtomicLong requested = new AtomicLong(); Flux switchTransformed = publisher .flux() + .doOnRequest(requested::addAndGet) .transform( flux -> new SwitchTransformFlux<>( @@ -34,6 +255,74 @@ public void backpressureTest() { publisher.assertWasRequested(); publisher.assertNoRequestOverflow(); + + Assert.assertEquals(2L, requested.get()); + } + + @Test + public void backpressureConditionalTest() { + Flux publisher = Flux.range(0, 10000); + AtomicLong requested = new AtomicLong(); + + Flux switchTransformed = + publisher + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new SwitchTransformFlux<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) + .filter(e -> false); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + Assert.assertEquals(2L, requested.get()); + } + + @Test + public void backpressureHiddenConditionalTest() { + Flux publisher = Flux.range(0, 10000); + AtomicLong requested = new AtomicLong(); + + Flux switchTransformed = + publisher + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new SwitchTransformFlux<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf).hide())) + .filter(e -> false); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + Assert.assertEquals(10001L, requested.get()); + } + + @Test + public void backpressureDrawbackOnConditionalInTransformTest() { + Flux publisher = Flux.range(0, 10000); + AtomicLong requested = new AtomicLong(); + + Flux switchTransformed = + publisher + .doOnRequest(requested::addAndGet) + .transform( + flux -> + new SwitchTransformFlux<>( + flux, + (first, innerFlux) -> innerFlux.map(String::valueOf).filter(e -> false))); + + StepVerifier.create(switchTransformed, 0) + .thenRequest(1) + .expectComplete() + .verify(Duration.ofSeconds(10)); + + Assert.assertEquals(10001L, requested.get()); } @Test @@ -98,11 +387,58 @@ public void shouldBeAbleToBeCancelledProperly() { new SwitchTransformFlux<>( flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - publisher.emit(1, 2, 3, 4, 5); + publisher.next(1); + + StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); + + publisher.assertCancelled(); + publisher.assertWasRequested(); + } + + @Test + public void shouldBeAbleToCatchDiscardedElement() { + TestPublisher publisher = TestPublisher.createCold(); + Integer[] discarded = new Integer[1]; + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new SwitchTransformFlux<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) + .doOnDiscard(Integer.class, e -> discarded[0] = e); + + publisher.next(1); StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); publisher.assertCancelled(); publisher.assertWasRequested(); + + Assert.assertArrayEquals(new Integer[] {1}, discarded); + } + + @Test + public void shouldBeAbleToCatchDiscardedElementInCaseOfConditional() { + TestPublisher publisher = TestPublisher.createCold(); + Integer[] discarded = new Integer[1]; + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new SwitchTransformFlux<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) + .filter(t -> true) + .doOnDiscard(Integer.class, e -> discarded[0] = e); + + publisher.next(1); + + StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); + + publisher.assertCancelled(); + publisher.assertWasRequested(); + + Assert.assertArrayEquals(new Integer[] {1}, discarded); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java index 9ef1f394b..2f54ddb50 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java @@ -41,7 +41,6 @@ public PingHandler(byte[] data) { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { - setup.release(); return Mono.just( new AbstractRSocket() { @Override