From b5740cd774c5bdb714c00d64338e1685215c1584 Mon Sep 17 00:00:00 2001 From: KaylaBrady <31781298+KaylaBrady@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:53:52 -0500 Subject: [PATCH] feat(android): leave / rejoin predictions & alerts after backgrounding --- .../android/state/SubscribeToAlertsTest.kt | 92 +++++---- .../state/SubscribeToPredictionsTest.kt | 177 ++++++++---------- .../android/state/subscribeToAlerts.kt | 58 +++--- .../android/state/subscribeToPredictions.kt | 115 ++++++------ .../mbta_app/repositories/AlertsRepository.kt | 10 +- 5 files changed, 227 insertions(+), 225 deletions(-) diff --git a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToAlertsTest.kt b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToAlertsTest.kt index 4aef266a8..8775bd437 100644 --- a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToAlertsTest.kt +++ b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToAlertsTest.kt @@ -1,36 +1,21 @@ package com.mbta.tid.mbta_app.android.state +import androidx.compose.runtime.CompositionLocalProvider +import androidx.compose.runtime.getValue +import androidx.compose.runtime.setValue import androidx.compose.ui.test.junit4.createComposeRule -import androidx.lifecycle.ViewModel -import androidx.lifecycle.ViewModelProvider -import androidx.lifecycle.ViewModelStore -import com.mbta.tid.mbta_app.android.util.TimerViewModel +import androidx.lifecycle.Lifecycle +import androidx.lifecycle.compose.LocalLifecycleOwner +import androidx.lifecycle.testing.TestLifecycleOwner import com.mbta.tid.mbta_app.model.ObjectCollectionBuilder import com.mbta.tid.mbta_app.model.response.AlertsStreamDataResponse -import com.mbta.tid.mbta_app.model.response.ApiResult -import com.mbta.tid.mbta_app.repositories.IAlertsRepository +import com.mbta.tid.mbta_app.repositories.MockAlertsRepository import kotlin.test.assertEquals -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest import org.junit.Assert import org.junit.Rule import org.junit.Test -class MockAlertsRepository(private val scope: CoroutineScope) : IAlertsRepository { - lateinit var alertsStreamDataResponse: AlertsStreamDataResponse - var disconnectHook: () -> Unit = { println("original disconnect hook called") } - - override fun connect(onReceive: (ApiResult) -> Unit) { - scope.launch { onReceive(ApiResult.Ok(alertsStreamDataResponse)) } - } - - override fun disconnect() { - disconnectHook() - } -} - class SubscribeToAlertsTest { @get:Rule val composeRule = createComposeRule() @@ -42,33 +27,58 @@ class SubscribeToAlertsTest { header = "Alert 1" description = "Description 1" } + + var connectCount = 0 val alertsStreamDataResponse = AlertsStreamDataResponse(builder) - val alertsRepo = MockAlertsRepository(this.backgroundScope) - alertsRepo.alertsStreamDataResponse = alertsStreamDataResponse + val alertsRepo = MockAlertsRepository(alertsStreamDataResponse, { connectCount += 1 }) var actualData: AlertsStreamDataResponse? = null composeRule.setContent { actualData = subscribeToAlerts(alertsRepo) } - composeRule.awaitIdle() + composeRule.waitUntil { connectCount == 1 } assertEquals(alertsStreamDataResponse, actualData) } @Test - fun testAlertsOnClear() = runTest { - var disconnectCalled = false - val mockAlertsRepository = MockAlertsRepository(this.backgroundScope) - mockAlertsRepository.disconnectHook = { disconnectCalled = true } - val viewModelStore = ViewModelStore() - val viewModelProvider = - ViewModelProvider( - viewModelStore, - object : ViewModelProvider.Factory { - override fun create(modelClass: Class): T { - return AlertsViewModel(mockAlertsRepository, TimerViewModel(1.seconds)) as T - } - } + fun testDisconnectsOnPause() = runTest { + val lifecycleOwner = TestLifecycleOwner(Lifecycle.State.RESUMED) + + var connectCount = 0 + var disconnectCount = 0 + + val builder = ObjectCollectionBuilder() + builder.alert { + id = "1" + header = "Alert 1" + description = "Description 1" + } + + val alertsStreamDataResponse = AlertsStreamDataResponse(builder) + val alertsRepo = + MockAlertsRepository( + alertsStreamDataResponse, + { connectCount += 1 }, + { disconnectCount += 1 } ) - viewModelProvider.get(AlertsViewModel::class) - viewModelStore.clear() - Assert.assertEquals(true, disconnectCalled) + + var actualData: AlertsStreamDataResponse? = null + + composeRule.setContent { + CompositionLocalProvider(LocalLifecycleOwner provides lifecycleOwner) { + actualData = subscribeToAlerts(alertsRepo) + } + } + + composeRule.waitUntil { connectCount == 1 } + Assert.assertEquals(0, disconnectCount) + + composeRule.runOnIdle { lifecycleOwner.handleLifecycleEvent(Lifecycle.Event.ON_PAUSE) } + + composeRule.waitUntil { disconnectCount == 1 } + Assert.assertEquals(1, connectCount) + + composeRule.runOnIdle { lifecycleOwner.handleLifecycleEvent(Lifecycle.Event.ON_RESUME) } + + composeRule.waitUntil { connectCount == 2 } + Assert.assertEquals(1, disconnectCount) } } diff --git a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToPredictionsTest.kt b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToPredictionsTest.kt index 4508ec2cf..160df2c0a 100644 --- a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToPredictionsTest.kt +++ b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/state/SubscribeToPredictionsTest.kt @@ -1,133 +1,108 @@ package com.mbta.tid.mbta_app.android.state import androidx.activity.ComponentActivity +import androidx.compose.runtime.CompositionLocalProvider import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.test.junit4.createAndroidComposeRule -import androidx.lifecycle.ViewModel -import androidx.lifecycle.ViewModelProvider -import androidx.lifecycle.ViewModelStore -import com.mbta.tid.mbta_app.android.util.TimerViewModel +import androidx.lifecycle.Lifecycle +import androidx.lifecycle.compose.LocalLifecycleOwner +import androidx.lifecycle.testing.TestLifecycleOwner import com.mbta.tid.mbta_app.model.ObjectCollectionBuilder -import com.mbta.tid.mbta_app.model.response.ApiResult import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse -import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse -import com.mbta.tid.mbta_app.repositories.IPredictionsRepository -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.launch +import com.mbta.tid.mbta_app.repositories.MockPredictionsRepository import kotlinx.coroutines.test.runTest -import kotlinx.datetime.Instant import org.junit.Assert.assertEquals -import org.junit.Assert.assertNull import org.junit.Rule import org.junit.Test -class MockPredictionsRepository(private val scope: CoroutineScope) : IPredictionsRepository { - val stopIdsChannel = Channel>() - lateinit var onJoin: (ApiResult) -> Unit - lateinit var onMessage: (ApiResult) -> Unit - var disconnectHook: () -> Unit = { println("original disconnect hook called") } - - override fun connect( - stopIds: List, - onReceive: (ApiResult) -> Unit - ) { - /* null-op */ - } - - override fun connectV2( - stopIds: List, - onJoin: (ApiResult) -> Unit, - onMessage: (ApiResult) -> Unit - ) { - - this.onJoin = onJoin - scope.launch { stopIdsChannel.send(stopIds) } - } - - override var lastUpdated: Instant? = null - - override fun shouldForgetPredictions(predictionCount: Int) = false - - override fun disconnect() { - disconnectHook() - } -} - class SubscribeToPredictionsTest { @get:Rule val composeTestRule = createAndroidComposeRule() @Test fun testPredictions() = runTest { - fun buildSomePredictions(): PredictionsByStopJoinResponse { - val objects = ObjectCollectionBuilder() - objects.prediction() - objects.prediction() - return PredictionsByStopJoinResponse(objects) - } - val predictionsRepo = MockPredictionsRepository(this) + val objects = ObjectCollectionBuilder() + objects.prediction() + objects.prediction() + val predictionsOnJoin = PredictionsByStopJoinResponse(objects) + + var connectProps: List? = null + var disconnectCount = 0 + + val predictionsRepo = + MockPredictionsRepository( + {}, + { stops -> connectProps = stops }, + { disconnectCount += 1 }, + null, + predictionsOnJoin + ) - var stopIds by mutableStateOf(listOf("place-a")) - var unmounted by mutableStateOf(false) + var stopIds = mutableStateOf(listOf("place-a")) var predictions: PredictionsStreamDataResponse? = PredictionsStreamDataResponse(ObjectCollectionBuilder()) + composeTestRule.setContent { - if (!unmounted) predictions = subscribeToPredictions(stopIds, predictionsRepo) + var stopIds by remember { stopIds } + predictions = subscribeToPredictions(stopIds, predictionsRepo) } - composeTestRule.awaitIdle() - assertEquals(listOf("place-a"), predictionsRepo.stopIdsChannel.receive()) - assertNull(predictions) - - val expectedPredictions1 = buildSomePredictions() - predictionsRepo.onJoin(ApiResult.Ok(expectedPredictions1)) - composeTestRule.awaitIdle() - assertEquals(expectedPredictions1.toPredictionsStreamDataResponse(), predictions) - - stopIds = listOf("place-b") - composeTestRule.awaitIdle() - assertEquals(listOf("place-b"), predictionsRepo.stopIdsChannel.receive()) - predictionsRepo.onJoin(ApiResult.Ok(expectedPredictions1)) - composeTestRule.awaitIdle() - assertEquals(expectedPredictions1.toPredictionsStreamDataResponse(), predictions) - - val expectedPredictions2 = buildSomePredictions() - predictionsRepo.onJoin(ApiResult.Ok(expectedPredictions2)) - composeTestRule.awaitIdle() - assertEquals(expectedPredictions2.toPredictionsStreamDataResponse(), predictions) - - unmounted = true - composeTestRule.awaitIdle() + composeTestRule.waitUntil { connectProps == listOf("place-a") } + + composeTestRule.waitUntil { + predictions != null && + predictions == predictionsOnJoin?.toPredictionsStreamDataResponse() + } + + assertEquals(0, disconnectCount) + + stopIds.value = listOf("place-b") + composeTestRule.waitUntil { disconnectCount == 1 } + + composeTestRule.waitUntil { connectProps == listOf("place-b") } } @Test - fun testPredictionsOnClear() = runTest { - var disconnectCalled = false - val stopIds by mutableStateOf(listOf("place-a")) - val mockPredictionsRepository = MockPredictionsRepository(this.backgroundScope) - mockPredictionsRepository.disconnectHook = { disconnectCalled = true } - - val viewModelStore = ViewModelStore() - val viewModelProvider = - ViewModelProvider( - viewModelStore, - object : ViewModelProvider.Factory { - override fun create(modelClass: Class): T { - return PredictionsViewModel( - stopIds, - mockPredictionsRepository, - TimerViewModel(1.seconds) - ) - as T - } - } + fun testDisconnectsOnPause() = runTest { + val lifecycleOwner = TestLifecycleOwner(Lifecycle.State.RESUMED) + + var connectCount = 0 + var disconnectCount = 0 + + val predictionsRepo = + MockPredictionsRepository( + {}, + { stopIds -> connectCount += 1 }, + { disconnectCount += 1 }, + null, + null ) - viewModelProvider.get(PredictionsViewModel::class) - viewModelStore.clear() - assertEquals(true, disconnectCalled) + + var stopIds = mutableStateOf(listOf("place-a")) + var predictions: PredictionsStreamDataResponse? = + PredictionsStreamDataResponse(ObjectCollectionBuilder()) + + composeTestRule.setContent { + CompositionLocalProvider(LocalLifecycleOwner provides lifecycleOwner) { + var stopIds by remember { stopIds } + predictions = subscribeToPredictions(stopIds, predictionsRepo) + } + } + + composeTestRule.waitUntil { connectCount == 1 } + assertEquals(0, disconnectCount) + + composeTestRule.runOnIdle { lifecycleOwner.handleLifecycleEvent(Lifecycle.Event.ON_PAUSE) } + + composeTestRule.waitUntil { disconnectCount == 1 } + assertEquals(1, connectCount) + + composeTestRule.runOnIdle { lifecycleOwner.handleLifecycleEvent(Lifecycle.Event.ON_RESUME) } + + composeTestRule.waitUntil { connectCount == 2 } + assertEquals(1, disconnectCount) } } diff --git a/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToAlerts.kt b/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToAlerts.kt index 096d64059..b72374af8 100644 --- a/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToAlerts.kt +++ b/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToAlerts.kt @@ -2,53 +2,55 @@ package com.mbta.tid.mbta_app.android.state import androidx.compose.runtime.Composable import androidx.compose.runtime.collectAsState -import androidx.compose.runtime.remember import androidx.lifecycle.LiveData import androidx.lifecycle.MutableLiveData import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelProvider import androidx.lifecycle.asFlow -import com.mbta.tid.mbta_app.android.util.TimerViewModel +import androidx.lifecycle.compose.LifecycleResumeEffect +import androidx.lifecycle.viewmodel.compose.viewModel import com.mbta.tid.mbta_app.model.response.AlertsStreamDataResponse import com.mbta.tid.mbta_app.model.response.ApiResult import com.mbta.tid.mbta_app.repositories.IAlertsRepository -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.flow.collect -import kotlinx.coroutines.launch import okhttp3.internal.notifyAll import org.koin.compose.koinInject class AlertsViewModel( private val alertsRepository: IAlertsRepository, - private val timerViewModel: TimerViewModel ) : ViewModel() { private val _alerts = MutableLiveData(AlertsStreamDataResponse(emptyMap())) val alerts: LiveData = _alerts val alertFlow = alerts.asFlow() - init { - CoroutineScope(Dispatchers.IO).launch { - alertsRepository.connect { - when (it) { - is ApiResult.Ok -> { - _alerts.postValue(it.data) - val oldAlerts = alerts.value?.alerts ?: emptyMap() - if (oldAlerts.isEmpty() && it.data.alerts.isNotEmpty()) - synchronized(alerts) { alerts.notifyAll() } - } - is ApiResult.Error -> { - /* TODO("handle errors") */ - } + fun connect() { + alertsRepository.connect { + when (it) { + is ApiResult.Ok -> { + _alerts.postValue(it.data) + val oldAlerts = alerts.value?.alerts ?: emptyMap() + if (oldAlerts.isEmpty() && it.data.alerts.isNotEmpty()) + synchronized(alerts) { alerts.notifyAll() } + } + is ApiResult.Error -> { + /* TODO("handle errors") */ } } - timerViewModel.timerFlow.collect { synchronized(alerts) { alerts.notifyAll() } } } } + fun disconnect() { + alertsRepository.disconnect() + } + override fun onCleared() { super.onCleared() - alertsRepository.disconnect() + disconnect() + } + + class Factory(private val alertsRepository: IAlertsRepository) : ViewModelProvider.Factory { + override fun create(modelClass: Class): T { + return AlertsViewModel(alertsRepository) as T + } } } @@ -56,7 +58,13 @@ class AlertsViewModel( fun subscribeToAlerts( alertsRepository: IAlertsRepository = koinInject() ): AlertsStreamDataResponse? { - val timerViewModel = remember { TimerViewModel(1.seconds) } - val viewModel = remember { AlertsViewModel(alertsRepository, timerViewModel) } + val viewModel: AlertsViewModel = viewModel(factory = AlertsViewModel.Factory(alertsRepository)) + + LifecycleResumeEffect(key1 = null) { + viewModel.connect() + + onPauseOrDispose { viewModel.disconnect() } + } + return viewModel.alertFlow.collectAsState(initial = null).value } diff --git a/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToPredictions.kt b/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToPredictions.kt index 6c40854a9..8f6dfed94 100644 --- a/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToPredictions.kt +++ b/androidApp/src/main/java/com/mbta/tid/mbta_app/android/state/subscribeToPredictions.kt @@ -1,87 +1,89 @@ package com.mbta.tid.mbta_app.android.state +import android.util.Log import androidx.compose.runtime.Composable import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue -import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.lifecycle.LiveData import androidx.lifecycle.MutableLiveData import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelProvider import androidx.lifecycle.asFlow +import androidx.lifecycle.compose.LifecycleResumeEffect import androidx.lifecycle.viewmodel.compose.viewModel -import com.mbta.tid.mbta_app.android.util.TimerViewModel import com.mbta.tid.mbta_app.model.response.ApiResult import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.repositories.IPredictionsRepository -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.map -import kotlinx.coroutines.launch -import okhttp3.internal.notifyAll import org.koin.compose.koinInject class PredictionsViewModel( - private val stopIds: List, private val predictionsRepository: IPredictionsRepository, - private val timerViewModel: TimerViewModel ) : ViewModel() { private val _predictions: MutableLiveData = MutableLiveData() val predictions: LiveData = _predictions val predictionsFlow = predictions.asFlow().map { it.toPredictionsStreamDataResponse() } - init { - CoroutineScope(Dispatchers.IO).launch { - if (stopIds.size > 0) { - connectToPredictions() + override fun onCleared() { + super.onCleared() + predictionsRepository.disconnect() + } + + fun connect(stopIds: List?) { + + if (!stopIds.isNullOrEmpty()) { + predictionsRepository.connectV2(stopIds, ::handleJoinMessage, ::handlePushMessage) + } + } + + private fun handleJoinMessage(message: ApiResult) { + when (message) { + is ApiResult.Ok -> { + _predictions.postValue(message.data) } - timerViewModel.timerFlow.collect { - synchronized(predictions) { predictions.notifyAll() } + is ApiResult.Error -> { + Log.e( + "PredictionsViewModel", + "Predictions stream failed to join: ${message.message}" + ) } } } - override fun onCleared() { - super.onCleared() + private fun handlePushMessage(message: ApiResult) { + when (message) { + is ApiResult.Ok -> { + _predictions.postValue( + (_predictions.value + ?: PredictionsByStopJoinResponse( + mapOf(message.data.stopId to message.data.predictions), + message.data.trips, + message.data.vehicles + )) + .mergePredictions(message.data) + ) + } + is ApiResult.Error -> { + Log.e( + "PredictionsViewModel", + "Predictions stream failed on message: ${message.message}" + ) + } + } + } + + fun disconnect() { predictionsRepository.disconnect() } - private fun connectToPredictions() { - predictionsRepository.connectV2( - stopIds, - { - when (it) { - is ApiResult.Ok -> { - _predictions.postValue(it.data) - synchronized(predictions) { predictions.notifyAll() } - } - is ApiResult.Error -> { - /* TODO("handle errors") */ - } - } - }, - { - when (it) { - is ApiResult.Ok -> { - _predictions.postValue( - (_predictions.value - ?: PredictionsByStopJoinResponse( - mapOf(it.data.stopId to it.data.predictions), - it.data.trips, - it.data.vehicles - )) - .mergePredictions(it.data) - ) - } - is ApiResult.Error -> { - /* TODO("handle errors") */ - } - } - } - ) + class Factory(private val predictionsRepository: IPredictionsRepository) : + ViewModelProvider.Factory { + override fun create(modelClass: Class): T { + return PredictionsViewModel(predictionsRepository) as T + } } } @@ -90,10 +92,13 @@ fun subscribeToPredictions( stopIds: List?, predictionsRepository: IPredictionsRepository = koinInject() ): PredictionsStreamDataResponse? { - val timerViewModel = remember { TimerViewModel(1.seconds) } val viewModel: PredictionsViewModel = - remember(stopIds) { - PredictionsViewModel(stopIds ?: emptyList(), predictionsRepository, timerViewModel) - } + viewModel(factory = PredictionsViewModel.Factory(predictionsRepository)) + + LifecycleResumeEffect(key1 = stopIds) { + viewModel.connect(stopIds) + + onPauseOrDispose { viewModel.disconnect() } + } return viewModel.predictionsFlow.collectAsState(initial = null).value } diff --git a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/AlertsRepository.kt b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/AlertsRepository.kt index 0535e01d6..f96e875e4 100644 --- a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/AlertsRepository.kt +++ b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/AlertsRepository.kt @@ -82,14 +82,18 @@ class AlertsRepository(private val socket: PhoenixSocket) : IAlertsRepository, K class MockAlertsRepository @DefaultArgumentInterop.Enabled -constructor(private val response: AlertsStreamDataResponse = AlertsStreamDataResponse(emptyMap())) : - IAlertsRepository { +constructor( + private val response: AlertsStreamDataResponse = AlertsStreamDataResponse(emptyMap()), + private val onConnect: () -> Unit = {}, + private val onDisconnect: () -> Unit = {} +) : IAlertsRepository { override fun connect(onReceive: (ApiResult) -> Unit) { + onConnect() onReceive(ApiResult.Ok(response)) } override fun disconnect() { - /* no-op */ + onDisconnect() } }