diff --git a/reaktive/src/commonMain/kotlin/com/badoo/reaktive/observable/RefCount.kt b/reaktive/src/commonMain/kotlin/com/badoo/reaktive/observable/RefCount.kt index b865fca2b..5f91bbbaf 100644 --- a/reaktive/src/commonMain/kotlin/com/badoo/reaktive/observable/RefCount.kt +++ b/reaktive/src/commonMain/kotlin/com/badoo/reaktive/observable/RefCount.kt @@ -2,10 +2,9 @@ package com.badoo.reaktive.observable import com.badoo.reaktive.disposable.CompositeDisposable import com.badoo.reaktive.disposable.Disposable +import com.badoo.reaktive.disposable.SerialDisposable import com.badoo.reaktive.disposable.plusAssign -import com.badoo.reaktive.utils.atomic.AtomicInt -import com.badoo.reaktive.utils.atomic.AtomicReference -import com.badoo.reaktive.utils.atomic.getAndChange +import com.badoo.reaktive.utils.lock.Lock /** * Returns an [Observable] that connects to this [ConnectableObservable] when the number @@ -16,23 +15,15 @@ import com.badoo.reaktive.utils.atomic.getAndChange fun ConnectableObservable.refCount(subscriberCount: Int = 1): Observable { require(subscriberCount > 0) - val subscribeCount = AtomicInt() - val disposable = AtomicReference(null) + var subscribeCount = 0 + val lock = Lock() + val connectionDisposable = SerialDisposable() return observable { emitter -> val disposables = CompositeDisposable() emitter.setDisposable(disposables) - disposables += - Disposable { - if (subscribeCount.addAndGet(-1) == 0) { - disposable - .getAndChange { null } - ?.dispose() - } - } - - val shouldConnect = subscribeCount.addAndGet(1) == subscriberCount + val shouldConnect = lock.synchronized { ++subscribeCount == subscriberCount } this@refCount.subscribe( object : ObservableObserver, ObservableCallbacks by emitter { @@ -43,9 +34,16 @@ fun ConnectableObservable.refCount(subscriberCount: Int = 1): Observable< ) if (shouldConnect) { - this@refCount.connect { - disposable.value = it - } + this@refCount.connect(connectionDisposable::set) } + + disposables += + Disposable { + lock.synchronized { + if (--subscribeCount == 0) { + connectionDisposable.set(null) + } + } + } } } diff --git a/reaktive/src/commonTest/kotlin/com/badoo/reaktive/observable/RefCountTest.kt b/reaktive/src/commonTest/kotlin/com/badoo/reaktive/observable/RefCountTest.kt index f05751d69..edb31621e 100644 --- a/reaktive/src/commonTest/kotlin/com/badoo/reaktive/observable/RefCountTest.kt +++ b/reaktive/src/commonTest/kotlin/com/badoo/reaktive/observable/RefCountTest.kt @@ -82,6 +82,40 @@ class RefCountTest { assertTrue(disposable.isDisposed) } + @Test + fun connects_to_upstream_WHEN_subscriberCount_is_1_and_subscribed_and_disposed_in_onSubscribe() { + var isConnected = false + val upstream = testUpstream(connect = { isConnected = true }) + val refCount = upstream.refCount(subscriberCount = 1) + + refCount.subscribe( + object : DefaultObservableObserver { + override fun onSubscribe(disposable: Disposable) { + disposable.dispose() + } + } + ) + + assertTrue(isConnected) + } + + @Test + fun disconnects_from_upstream_WHEN_subscriberCount_is_1_and_subscribed_and_disposed_in_onSubscribe() { + val disposable = Disposable() + val upstream = testUpstream(connect = { onConnect -> onConnect?.invoke(disposable) }) + val refCount = upstream.refCount(subscriberCount = 1) + + refCount.subscribe( + object : DefaultObservableObserver { + override fun onSubscribe(disposable: Disposable) { + disposable.dispose() + } + } + ) + + assertTrue(disposable.isDisposed) + } + @Test fun disconnects_from_upstream_WHEN_subscriberCount_is_2_and_all_subscribers_unsubscribed() { val disposable = Disposable() diff --git a/reaktive/src/jvmNativeCommonTest/kotlin/com/badoo/reaktive/observable/RefCountThreadingTest.kt b/reaktive/src/jvmNativeCommonTest/kotlin/com/badoo/reaktive/observable/RefCountThreadingTest.kt new file mode 100644 index 000000000..c15e61e49 --- /dev/null +++ b/reaktive/src/jvmNativeCommonTest/kotlin/com/badoo/reaktive/observable/RefCountThreadingTest.kt @@ -0,0 +1,72 @@ +package com.badoo.reaktive.observable + +import com.badoo.reaktive.disposable.Disposable +import com.badoo.reaktive.test.doInBackground +import com.badoo.reaktive.test.observable.test +import com.badoo.reaktive.utils.lock.ConditionLock +import com.badoo.reaktive.utils.lock.synchronized +import com.badoo.reaktive.utils.lock.waitFor +import com.badoo.reaktive.utils.lock.waitForOrFail +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.time.Duration.Companion.seconds + +class RefCountThreadingTest { + + @Test + fun does_not_connect_second_time_concurrently_while_disconnecting() { + val lock = ConditionLock() + var isDisconnecting = false + var isSecondTime = false + var isConnectedSecondTimeConcurrently = false + + val disposable = + Disposable { + lock.synchronized { + isDisconnecting = true + isSecondTime = true + lock.signal() + lock.waitFor(timeout = 1.seconds) { false } + isDisconnecting = false + } + } + + val upstream = + testUpstream( + connect = { onConnect -> + lock.synchronized { + if (!isSecondTime) { + onConnect?.invoke(disposable) + } else { + isConnectedSecondTimeConcurrently = isDisconnecting + } + } + } + ) + + val refCount = upstream.refCount(subscriberCount = 1) + val observer = refCount.test() + doInBackground { observer.dispose() } + + lock.synchronized { + lock.waitForOrFail { !isSecondTime } + } + + refCount.test() + + assertFalse(isConnectedSecondTimeConcurrently) + } + + private fun testUpstream( + connect: (onConnect: ((Disposable) -> Unit)?) -> Unit = {}, + ): ConnectableObservable = + object : ConnectableObservable { + override fun connect(onConnect: ((Disposable) -> Unit)?) { + connect.invoke(onConnect) + } + + override fun subscribe(observer: ObservableObserver) { + observer.onSubscribe(Disposable()) + } + } +}