diff --git a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitPageTest.kt b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitPageTest.kt index df29d386d..991cdee5a 100644 --- a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitPageTest.kt +++ b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitPageTest.kt @@ -23,6 +23,8 @@ import com.mbta.tid.mbta_app.model.SocketError import com.mbta.tid.mbta_app.model.response.AlertsStreamDataResponse import com.mbta.tid.mbta_app.model.response.GlobalResponse import com.mbta.tid.mbta_app.model.response.NearbyResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.repositories.INearbyRepository import com.mbta.tid.mbta_app.repositories.IPinnedRoutesRepository @@ -185,6 +187,15 @@ class NearbyTransitPageTest : KoinTest { onReceive(Outcome(PredictionsStreamDataResponse(builder), null)) } + override fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: + (Outcome) -> Unit + ) { + /* no-op */ + } + override fun disconnect() { /* no-op */ } diff --git a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitViewTest.kt b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitViewTest.kt index 12816c9bb..cd166d9ad 100644 --- a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitViewTest.kt +++ b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/nearbyTransit/NearbyTransitViewTest.kt @@ -12,6 +12,8 @@ import com.mbta.tid.mbta_app.model.RouteType import com.mbta.tid.mbta_app.model.SocketError import com.mbta.tid.mbta_app.model.response.GlobalResponse import com.mbta.tid.mbta_app.model.response.NearbyResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.repositories.INearbyRepository import com.mbta.tid.mbta_app.repositories.IPinnedRoutesRepository @@ -179,6 +181,15 @@ class NearbyTransitViewTest : KoinTest { onReceive(Outcome(PredictionsStreamDataResponse(builder), null)) } + override fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: + (Outcome) -> Unit + ) { + /* no-op */ + } + override fun disconnect() { /* no-op */ } diff --git a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/stopDetails/StopDetailsViewTest.kt b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/stopDetails/StopDetailsViewTest.kt index ad63564b4..1d23dc9a0 100644 --- a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/stopDetails/StopDetailsViewTest.kt +++ b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/stopDetails/StopDetailsViewTest.kt @@ -18,6 +18,8 @@ import com.mbta.tid.mbta_app.model.StopDetailsFilter import com.mbta.tid.mbta_app.model.UpcomingTrip import com.mbta.tid.mbta_app.model.response.GlobalResponse import com.mbta.tid.mbta_app.model.response.NearbyResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.repositories.IGlobalRepository import com.mbta.tid.mbta_app.repositories.INearbyRepository @@ -125,6 +127,15 @@ class StopDetailsViewTest { onReceive(Outcome(PredictionsStreamDataResponse(builder), null)) } + override fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: + (Outcome) -> Unit + ) { + /* no-op */ + } + override fun disconnect() { /* no-op */ } diff --git a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/util/SubscribeToPredictionsTest.kt b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/util/SubscribeToPredictionsTest.kt index b605a246a..191d73af4 100644 --- a/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/util/SubscribeToPredictionsTest.kt +++ b/androidApp/src/androidTest/java/com/mbta/tid/mbta_app/android/util/SubscribeToPredictionsTest.kt @@ -7,6 +7,8 @@ import androidx.compose.ui.test.junit4.createComposeRule import com.mbta.tid.mbta_app.model.ObjectCollectionBuilder import com.mbta.tid.mbta_app.model.Outcome import com.mbta.tid.mbta_app.model.SocketError +import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.repositories.IPredictionsRepository import kotlinx.coroutines.channels.Channel @@ -48,6 +50,14 @@ class SubscribeToPredictionsTest { this.onReceive = onReceive } + override fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: (Outcome) -> Unit + ) { + /* no-op */ + } + override fun disconnect() { check(isConnected) { "called disconnect when not connected" } isConnected = false diff --git a/androidApp/src/main/java/com/mbta/tid/mbta_app/android/util/subscribeToPredictions.kt b/androidApp/src/main/java/com/mbta/tid/mbta_app/android/util/subscribeToPredictions.kt index 275517c3f..59554ff0f 100644 --- a/androidApp/src/main/java/com/mbta/tid/mbta_app/android/util/subscribeToPredictions.kt +++ b/androidApp/src/main/java/com/mbta/tid/mbta_app/android/util/subscribeToPredictions.kt @@ -10,7 +10,6 @@ import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.repositories.IPredictionsRepository import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.cancel import kotlinx.coroutines.launch import org.koin.compose.koinInject diff --git a/iosApp/iosApp/Localizable.xcstrings b/iosApp/iosApp/Localizable.xcstrings index a00b08505..0c4928206 100644 --- a/iosApp/iosApp/Localizable.xcstrings +++ b/iosApp/iosApp/Localizable.xcstrings @@ -421,6 +421,9 @@ }, "Updated: %@" : { "comment" : "Interpolated value is a timestamp for when displayed alert details were last changed" + }, + "Using Predictions Channel V2" : { + }, "Weather" : { "comment" : "Possible alert cause" diff --git a/iosApp/iosApp/Pages/NearbyTransit/NearbyTransitView.swift b/iosApp/iosApp/Pages/NearbyTransit/NearbyTransitView.swift index 8d3a02742..3fcf2a960 100644 --- a/iosApp/iosApp/Pages/NearbyTransit/NearbyTransitView.swift +++ b/iosApp/iosApp/Pages/NearbyTransit/NearbyTransitView.swift @@ -20,6 +20,7 @@ struct NearbyTransitView: View { var pinnedRouteRepository = RepositoryDI().pinnedRoutes @State var predictionsRepository = RepositoryDI().predictions var schedulesRepository = RepositoryDI().schedules + var settingsRepository = RepositoryDI().settings var getNearby: (GlobalResponse, CLLocationCoordinate2D) -> Void @Binding var state: NearbyViewModel.NearbyTransitState @Binding var location: CLLocationCoordinate2D? @@ -30,8 +31,10 @@ struct NearbyTransitView: View { @State var nearbyWithRealtimeInfo: [StopsAssociated]? @State var now = Date.now @State var pinnedRoutes: Set = [] + @State var predictionsByStop: PredictionsByStopJoinResponse? @State var predictions: PredictionsStreamDataResponse? @State var predictionsError: SocketError? + @State var predictionsV2Enabled = false let timer = Timer.publish(every: 5, on: .main, in: .common).autoconnect() let inspection = Inspection() @@ -48,7 +51,10 @@ struct NearbyTransitView: View { .onAppear { getGlobal() getNearby(location: location, globalData: globalData) - joinPredictions(state.nearbyByRouteAndStop?.stopIds()) + Task { + await getPredictionsFeatureFlag() + joinPredictions(state.nearbyByRouteAndStop?.stopIds()) + } updateNearbyRoutes() updatePinnedRoutes() getSchedule() @@ -69,6 +75,14 @@ struct NearbyTransitView: View { .onChange(of: scheduleResponse) { response in updateNearbyRoutes(scheduleResponse: response) } + .onChange(of: predictionsByStop) { newPredictionsByStop in + if let newPredictionsByStop { + let condensedPredictions = newPredictionsByStop.toPredictionsStreamDataResponse() + updateNearbyRoutes(predictions: condensedPredictions) + } else { + updateNearbyRoutes(predictions: nil) + } + } .onChange(of: predictions) { predictions in updateNearbyRoutes(predictions: predictions) } @@ -105,6 +119,9 @@ struct NearbyTransitView: View { } else { ScrollViewReader { proxy in ScrollView { + if predictionsV2Enabled { + Text("Using Predictions Channel V2") + } LazyVStack { ForEach(transit, id: \.id) { nearbyTransit in switch onEnum(of: nearbyTransit) { @@ -142,7 +159,9 @@ struct NearbyTransitView: View { private func errorCard(_ errorText: String) -> some View { IconCard(iconName: "network.slash", details: Text(errorText)) - .refreshable(state.loading) { getNearby(location: location, globalData: globalData) } + .refreshable(state.loading) { + getNearby(location: location, globalData: globalData) + } } var didAppear: ((Self) -> Void)? @@ -171,18 +190,62 @@ struct NearbyTransitView: View { } } + func getPredictionsFeatureFlag() async { + do { + let settings = try await settingsRepository.getSettings() + predictionsV2Enabled = settings.first(where: { $0.key == .predictionsV2Channel })?.isOn ?? false + } catch {} + } + func joinPredictions(_ stopIds: Set?) { guard let stopIds else { return } - predictionsRepository.connect(stopIds: Array(stopIds)) { outcome in + if predictionsV2Enabled { + joinPredictionsV2(stopIds: stopIds) + } else { + predictionsRepository.connect(stopIds: Array(stopIds)) { outcome in + DispatchQueue.main.async { + if let data = outcome.data { + predictions = data + predictionsError = nil + } else if let error = outcome.error { + predictionsError = error.toSwiftEnum() + } + } + } + } + } + + func joinPredictionsV2(stopIds: Set) { + predictionsRepository.connectV2(stopIds: Array(stopIds), onJoin: { outcome in DispatchQueue.main.async { if let data = outcome.data { - predictions = data + predictionsByStop = data predictionsError = nil } else if let error = outcome.error { predictionsError = error.toSwiftEnum() } } - } + }, onMessage: { outcome in + DispatchQueue.main.async { + if let data = outcome.data { + if let existingPredictionsByStop = predictionsByStop { + predictionsByStop = existingPredictionsByStop.mergePredictions(updatedPredictions: data) + predictionsError = nil + } else { + predictionsByStop = PredictionsByStopJoinResponse( + predictionsByStop: [data.stopId: data.predictions], + trips: data.trips, + vehicles: data.vehicles + ) + predictionsError = nil + } + + } else if let error = outcome.error { + predictionsError = error.toSwiftEnum() + } + } + + }) } func leavePredictions() { @@ -223,9 +286,13 @@ struct NearbyTransitView: View { alerts: AlertsStreamDataResponse? = nil, pinnedRoutes: Set? = nil ) { + let fallbackPredictions = if let predictionsByStop { + predictionsByStop.toPredictionsStreamDataResponse() + } else { self.predictions } + nearbyWithRealtimeInfo = withRealtimeInfo( schedules: scheduleResponse ?? self.scheduleResponse, - predictions: predictions ?? self.predictions, + predictions: predictions ?? fallbackPredictions, alerts: alerts ?? nearbyVM.alerts, filterAtTime: now.toKotlinInstant(), pinnedRoutes: pinnedRoutes ?? self.pinnedRoutes diff --git a/iosApp/iosApp/Pages/Settings/Setting+Convenience.swift b/iosApp/iosApp/Pages/Settings/Setting+Convenience.swift index 82c439d35..34ce4d698 100644 --- a/iosApp/iosApp/Pages/Settings/Setting+Convenience.swift +++ b/iosApp/iosApp/Pages/Settings/Setting+Convenience.swift @@ -22,6 +22,8 @@ extension Setting: Identifiable { "Search - Route Results" case .map: "Map Debug" + case .predictionsV2Channel: + "Predictions V2 Channel" } } @@ -33,6 +35,8 @@ extension Setting: Identifiable { "point.topleft.down.to.point.bottomright.curvepath.fill" case .map: "location.magnifyingglass" + case .predictionsV2Channel: + "magnifyingglass" } } @@ -42,6 +46,8 @@ extension Setting: Identifiable { .featureFlags case .searchRouteResults: .featureFlags + case .predictionsV2Channel: + .featureFlags case .map: .debug } diff --git a/iosApp/iosApp/Pages/StopDetails/StopDetailsPage.swift b/iosApp/iosApp/Pages/StopDetails/StopDetailsPage.swift index b90fd6f02..8c70fff92 100644 --- a/iosApp/iosApp/Pages/StopDetails/StopDetailsPage.swift +++ b/iosApp/iosApp/Pages/StopDetails/StopDetailsPage.swift @@ -16,6 +16,7 @@ struct StopDetailsPage: View { @State var globalResponse: GlobalResponse? @ObservedObject var viewportProvider: ViewportProvider let schedulesRepository: ISchedulesRepository + var settingsRepository = RepositoryDI().settings @State var schedulesResponse: ScheduleResponse? var pinnedRouteRepository = RepositoryDI().pinnedRoutes var togglePinnedUsecase = UsecaseDI().toggledPinnedRouteUsecase @@ -27,55 +28,72 @@ struct StopDetailsPage: View { @ObservedObject var nearbyVM: NearbyViewModel @State var pinnedRoutes: Set = [] @State var predictions: PredictionsStreamDataResponse? + @State var predictionsByStop: PredictionsByStopJoinResponse? + @State var predictionsV2Enabled = false let inspection = Inspection() let timer = Timer.publish(every: 5, on: .main, in: .common).autoconnect() + var didAppear: ((Self) -> Void)? + init( globalRepository: IGlobalRepository = RepositoryDI().global, schedulesRepository: ISchedulesRepository = RepositoryDI().schedules, + settingsRepository: ISettingsRepository = RepositoryDI().settings, predictionsRepository: IPredictionsRepository = RepositoryDI().predictions, viewportProvider: ViewportProvider, stop: Stop, filter: Binding, - nearbyVM: NearbyViewModel + nearbyVM: NearbyViewModel, + predictionsV2Enabled: Bool = false ) { self.globalRepository = globalRepository self.schedulesRepository = schedulesRepository + self.settingsRepository = settingsRepository self.predictionsRepository = predictionsRepository self.viewportProvider = viewportProvider self.stop = stop _filter = filter self.nearbyVM = nearbyVM + self.predictionsV2Enabled = predictionsV2Enabled } var body: some View { - StopDetailsView( - stop: stop, - filter: $filter, - nearbyVM: nearbyVM, - pinnedRoutes: pinnedRoutes, - togglePinnedRoute: togglePinnedRoute - ) - .onAppear { - loadGlobalData() - changeStop(stop) - loadPinnedRoutes() - } - .onChange(of: stop) { nextStop in changeStop(nextStop) } - .onChange(of: globalResponse) { _ in updateDepartures() } - .onChange(of: pinnedRoutes) { _ in updateDepartures() } - .onChange(of: predictions) { _ in updateDepartures() } - .onChange(of: schedulesResponse) { _ in updateDepartures() } - .onReceive(inspection.notice) { inspection.visit(self, $0) } - .onReceive(timer) { input in - now = input - updateDepartures() + VStack { + if predictionsV2Enabled { + Text("Using Predictions Channel V2") + } + StopDetailsView( + stop: stop, + filter: $filter, + nearbyVM: nearbyVM, + pinnedRoutes: pinnedRoutes, + togglePinnedRoute: togglePinnedRoute + ) + .onAppear { + loadGlobalData() + changeStop(stop) + loadPinnedRoutes() + didAppear?(self) + } + .onChange(of: stop) { nextStop in changeStop(nextStop) } + .onChange(of: globalResponse) { _ in updateDepartures() } + .onChange(of: pinnedRoutes) { _ in updateDepartures() } + .onChange(of: predictionsByStop) { newPredictionsByStop in + updateDepartures(stop, newPredictionsByStop, predictions) + } + .onChange(of: predictions) { _ in updateDepartures() } + .onChange(of: schedulesResponse) { _ in updateDepartures() } + .onReceive(inspection.notice) { inspection.visit(self, $0) } + .onReceive(timer) { input in + now = input + updateDepartures() + } + .onDisappear { leavePredictions() } + .withScenePhaseHandlers(onActive: { joinPredictions(stop) }, + onInactive: leavePredictions, + onBackground: leavePredictions) } - .onDisappear { leavePredictions() } - .withScenePhaseHandlers(onActive: { joinPredictions(stop) }, - onInactive: leavePredictions, - onBackground: leavePredictions) } func loadGlobalData() { @@ -126,30 +144,75 @@ struct StopDetailsPage: View { } func joinPredictions(_ stop: Stop) { - predictionsRepository.connect(stopIds: [stop.id]) { outcome in - DispatchQueue.main.async { - predictions = if let data = outcome.data { - data - } else { - nil + Task { + let settings = try await settingsRepository.getSettings() + var isEnabled = settings.first(where: { $0.key == .predictionsV2Channel })?.isOn ?? false + predictionsV2Enabled = isEnabled + if isEnabled { + joinPredictionsV2(stopIds: [stop.id]) + } else { + predictionsRepository.connect(stopIds: [stop.id]) { outcome in + DispatchQueue.main.async { + predictions = if let data = outcome.data { + data + } else { + nil + } + } } } } } + func joinPredictionsV2(stopIds: Set) { + predictionsRepository.connectV2(stopIds: Array(stopIds), onJoin: { outcome in + DispatchQueue.main.async { + if let data = outcome.data { + predictionsByStop = data + } + } + }, onMessage: { outcome in + DispatchQueue.main.async { + if let data = outcome.data { + if let existingPredictionsByStop = predictionsByStop { + predictionsByStop = existingPredictionsByStop.mergePredictions(updatedPredictions: data) + } else { + predictionsByStop = PredictionsByStopJoinResponse( + predictionsByStop: [data.stopId: data.predictions], + trips: data.trips, + vehicles: data.vehicles + ) + } + } + } + + }) + } + func leavePredictions() { predictionsRepository.disconnect() } - func updateDepartures(_ stop: Stop? = nil) { + func updateDepartures( + _ stop: Stop? = nil, + _ predictionsByStop: PredictionsByStopJoinResponse? = nil, + _ predictions: PredictionsStreamDataResponse? = nil + ) { let stop = stop ?? self.stop + let predictionsByStop = predictionsByStop ?? self.predictionsByStop + + let targetPredictions = if let predictionsByStop { + predictionsByStop.toPredictionsStreamDataResponse() + } else { + predictions ?? self.predictions + } let newDepartures: StopDetailsDepartures? = if let globalResponse { StopDetailsDepartures( stop: stop, global: globalResponse, schedules: schedulesResponse, - predictions: predictions, + predictions: targetPredictions, alerts: nearbyVM.alerts, pinnedRoutes: pinnedRoutes, filterAtTime: now.toKotlinInstant() diff --git a/iosApp/iosAppTests/Pages/StopDetails/StopDetailsPageTests.swift b/iosApp/iosAppTests/Pages/StopDetails/StopDetailsPageTests.swift index 10fdcf0c3..5f4885e12 100644 --- a/iosApp/iosAppTests/Pages/StopDetails/StopDetailsPageTests.swift +++ b/iosApp/iosAppTests/Pages/StopDetails/StopDetailsPageTests.swift @@ -232,31 +232,52 @@ final class StopDetailsPageTests: XCTestCase { let leaveExpectation = expectation(description: "leaves predictions") - class FakePredictionsRepo: IPredictionsRepository { - let joinExpectation: XCTestExpectation - let leaveExpectation: XCTestExpectation + let predictionsRepo = MockPredictionsRepository(onConnect: { joinExpectation.fulfill() }, + onConnectV2: {}, + onDisconnect: { leaveExpectation.fulfill() }, + connectOutcome: nil, + connectV2Outcome: nil) + let sut = StopDetailsPage( + schedulesRepository: MockScheduleRepository(), + predictionsRepository: predictionsRepo, + viewportProvider: viewportProvider, + stop: stop, + filter: filter, + nearbyVM: .init() + ) - init(joinExpectation: XCTestExpectation, leaveExpectation: XCTestExpectation) { - self.joinExpectation = joinExpectation - self.leaveExpectation = leaveExpectation - } + ViewHosting.host(view: sut) - func connect( - stopIds _: [String], - onReceive _: @escaping (Outcome) - -> Void - ) { - joinExpectation.fulfill() - } + try sut.inspect().find(StopDetailsView.self).callOnChange(newValue: ScenePhase.background) - func disconnect() { - leaveExpectation.fulfill() - } - } + wait(for: [leaveExpectation], timeout: 1) + + try sut.inspect().find(StopDetailsView.self).callOnChange(newValue: ScenePhase.active) + + wait(for: [joinExpectation], timeout: 1) + } + + func testJoinsPredictionsV2WhenEnabled() throws { + let objects = ObjectCollectionBuilder() + let route = objects.route() + let stop = objects.stop { _ in } + + let viewportProvider: ViewportProvider = .init(viewport: .followPuck(zoom: 1)) + let filter: Binding = .constant(.init(routeId: route.id, directionId: 0)) + let joinExpectation = expectation(description: "joins predictions") + joinExpectation.expectedFulfillmentCount = 2 + joinExpectation.assertForOverFulfill = true - let predictionsRepo = FakePredictionsRepo(joinExpectation: joinExpectation, leaveExpectation: leaveExpectation) + let leaveExpectation = expectation(description: "leaves predictions") + + let predictionsRepo = MockPredictionsRepository(onConnect: {}, + onConnectV2: { joinExpectation.fulfill() }, + onDisconnect: { leaveExpectation.fulfill() }, + connectOutcome: nil, + connectV2Outcome: nil) let sut = StopDetailsPage( schedulesRepository: MockScheduleRepository(), + settingsRepository: MockSettingsRepository(settings: [.init(key: .predictionsV2Channel, isOn: true)]), predictionsRepository: predictionsRepo, viewportProvider: viewportProvider, stop: stop, @@ -275,6 +296,49 @@ final class StopDetailsPageTests: XCTestCase { wait(for: [joinExpectation], timeout: 1) } + func testUpdatesDeparturesOnV2PredictionsChange() throws { + let objects = ObjectCollectionBuilder() + let route = objects.route() + let stop = objects.stop { _ in } + let prediction = objects.prediction { _ in } + let pattern = objects.routePattern(route: route) { _ in } + let trip = objects.trip { trip in + trip.id = prediction.tripId + trip.stopIds = [stop.id] + } + + let viewportProvider: ViewportProvider = .init(viewport: .followPuck(zoom: 1)) + let filter: Binding = .constant(.init(routeId: route.id, directionId: 0)) + + let nearbyVM: NearbyViewModel = .init() + + let globalDataLoaded = PassthroughSubject() + + let predictionsRepo = MockPredictionsRepository() + var sut = StopDetailsPage( + globalRepository: MockGlobalRepository(response: .init(objects: objects, + patternIdsByStop: [stop.id: [pattern.id]]), + onGet: { globalDataLoaded.send() }), + schedulesRepository: MockScheduleRepository(), + settingsRepository: MockSettingsRepository(settings: [.init(key: .predictionsV2Channel, isOn: true)]), + predictionsRepository: predictionsRepo, + viewportProvider: viewportProvider, + stop: stop, + filter: filter, + nearbyVM: nearbyVM + ) + + XCTAssertNil(nearbyVM.departures) + let hasAppeared = sut.inspection.inspect(onReceive: globalDataLoaded, after: 1) { view in + try view.find(StopDetailsView.self) + .callOnChange(newValue: PredictionsByStopJoinResponse(objects: objects)) + XCTAssertNotNil(nearbyVM.departures) + } + ViewHosting.host(view: sut) + + wait(for: [hasAppeared], timeout: 5) + } + func testAppliesFilterAutomatically() throws { let objects = ObjectCollectionBuilder() let route = objects.route() diff --git a/iosApp/iosAppTests/Views/NearbyTransitViewTests.swift b/iosApp/iosAppTests/Views/NearbyTransitViewTests.swift index 7496fae70..7e97dff0f 100644 --- a/iosApp/iosAppTests/Views/NearbyTransitViewTests.swift +++ b/iosApp/iosAppTests/Views/NearbyTransitViewTests.swift @@ -370,6 +370,95 @@ final class NearbyTransitViewTests: XCTestCase { wait(for: [exp], timeout: 1) } + @MainActor func testWithPredictionsV2() throws { + NSTimeZone.default = TimeZone(identifier: "America/New_York")! + let now = Date.now + let distantMinutes: Double = 10 + let distantInstant = now.addingTimeInterval(distantMinutes * 60).toKotlinInstant() + let objects = ObjectCollectionBuilder() + let route = objects.route() + + let rp1 = objects.routePattern(route: route) { routePattern in + routePattern.id = "52-5-0" + routePattern.representativeTrip { representativeTrip in + representativeTrip.headsign = "Dedham Mall" + representativeTrip.routePatternId = routePattern.id + } + } + let rp2 = objects.routePattern(route: route) { routePattern in + routePattern.id = "52-4-1" + routePattern.representativeTrip { representativeTrip in + representativeTrip.headsign = "Watertown Yard" + representativeTrip.routePatternId = routePattern.id + } + } + objects.prediction { prediction in + prediction.arrivalTime = now.addingTimeInterval(distantMinutes * 60).toKotlinInstant() + prediction.departureTime = now.addingTimeInterval((distantMinutes + 2) * 60).toKotlinInstant() + prediction.routeId = "52" + prediction.stopId = "8552" + prediction.tripId = objects.trip(routePattern: rp1).id + } + objects.prediction { prediction in + prediction.arrivalTime = now.addingTimeInterval((distantMinutes + 1) * 60).toKotlinInstant() + prediction.departureTime = now.addingTimeInterval((distantMinutes + 5) * 60).toKotlinInstant() + prediction.status = "Overridden" + prediction.routeId = "52" + prediction.stopId = "8552" + prediction.tripId = objects.trip(routePattern: rp1).id + } + objects.prediction { prediction in + prediction.arrivalTime = now.addingTimeInterval(1 * 60 + 1).toKotlinInstant() + prediction.departureTime = now.addingTimeInterval(2 * 60).toKotlinInstant() + prediction.routeId = "52" + prediction.stopId = "84791" + prediction.tripId = objects.trip(routePattern: rp2).id + } + objects.prediction { prediction in + prediction.departureTime = distantInstant + prediction.routeId = "52" + prediction.stopId = "84791" + prediction.tripId = objects.trip(routePattern: rp2).id + } + let predictionsByStop: PredictionsByStopJoinResponse = .init(objects: objects) + + var sut = NearbyTransitView( + togglePinnedUsecase: TogglePinnedRouteUsecase(repository: pinnedRoutesRepository), + pinnedRouteRepository: pinnedRoutesRepository, + predictionsRepository: MockPredictionsRepository(), + schedulesRepository: MockScheduleRepository(), + settingsRepository: MockSettingsRepository(settings: [.init(key: .predictionsV2Channel, isOn: true)]), + getNearby: { _, _ in }, + state: .constant(route52State), + location: .constant(CLLocationCoordinate2D(latitude: 12.34, longitude: -56.78)), + nearbyVM: .init(), + now: now + ) + + let exp = sut.on(\.didAppear) { view in + try view.vStack().callOnChange(newValue: predictionsByStop) + let stops = view.findAll(NearbyStopView.self) + XCTAssertNotNil(try stops[0].find(text: "Charles River Loop") + .parent().parent().find(text: "No real-time data")) + + XCTAssertNotNil(try stops[0].find(text: "Dedham Mall") + .parent().parent().find(text: "10 min")) + XCTAssertNotNil(try stops[0].find(text: "Dedham Mall") + .parent().parent().find(text: "Overridden")) + + XCTAssertNotNil(try stops[1].find(text: "Watertown Yard") + .parent().parent().find(text: "1 min")) + let expectedMinutes = distantMinutes + let expectedState = UpcomingTripView.State.some(.Minutes(minutes: Int32(expectedMinutes))) + XCTAssert(try !stops[1].find(text: "Watertown Yard").parent().parent() + .findAll(UpcomingTripView.self, where: { sut in + try sut.actualView().prediction == expectedState + }).isEmpty) + } + ViewHosting.host(view: sut) + wait(for: [exp], timeout: 1) + } + @MainActor func testLineGrouping() throws { NSTimeZone.default = TimeZone(identifier: "America/New_York")! @@ -507,6 +596,12 @@ final class NearbyTransitViewTests: XCTestCase { } } + func connectV2(stopIds _: [String], + onJoin _: @escaping (Outcome) -> Void, + onMessage _: @escaping (Outcome) -> Void) { + /* no-op */ + } + func disconnect() { /* no-op */ } } @@ -591,35 +686,53 @@ final class NearbyTransitViewTests: XCTestCase { wait(for: [exp], timeout: 1) } - func testLeavesChannelWhenBackgrounded() throws { - let joinExpectation = expectation(description: "joins predictions") - let leaveExpectation = expectation(description: "leaves predictions") - - class FakePredictionsRepository: IPredictionsRepository { - let joinExpectation: XCTestExpectation - let leaveExpectation: XCTestExpectation + func testRendersUpdatedPredictionsV2() throws { + NSTimeZone.default = TimeZone(identifier: "America/New_York")! + var sut = NearbyTransitView( + togglePinnedUsecase: TogglePinnedRouteUsecase(repository: pinnedRoutesRepository), + pinnedRouteRepository: pinnedRoutesRepository, + predictionsRepository: MockPredictionsRepository(), + schedulesRepository: MockScheduleRepository(), + settingsRepository: MockSettingsRepository(settings: [.init(key: .predictionsV2Channel, isOn: true)]), + getNearby: { _, _ in }, + state: .constant(route52State), + location: .constant(CLLocationCoordinate2D(latitude: 12.34, longitude: -56.78)), + nearbyVM: .init() + ) - init(joinExpectation: XCTestExpectation, leaveExpectation: XCTestExpectation) { - self.joinExpectation = joinExpectation - self.leaveExpectation = leaveExpectation + func prediction(minutesAway: Double) -> PredictionsByStopJoinResponse { + let objects = ObjectCollectionBuilder() + let trip = objects.trip { trip in + trip.headsign = "Dedham Mall" + trip.routePatternId = "52-5-0" } - - func connect( - stopIds _: [String], - onReceive _: @escaping (Outcome) - -> Void - ) { - joinExpectation.fulfill() + objects.prediction { prediction in + prediction.departureTime = Date.now.addingTimeInterval(minutesAway * 60).toKotlinInstant() + prediction.routeId = "52" + prediction.stopId = "8552" + prediction.tripId = trip.id } + return PredictionsByStopJoinResponse(objects: objects) + } - func disconnect() { - leaveExpectation.fulfill() - } + let exp = sut.on(\.didAppear) { view in + try view.vStack().callOnChange(newValue: prediction(minutesAway: 2)) + XCTAssertNotNil(try view.vStack().find(text: "2 min")) + try view.vStack().callOnChange(newValue: prediction(minutesAway: 3)) + XCTAssertNotNil(try view.vStack().find(text: "3 min")) } + ViewHosting.host(view: sut) + wait(for: [exp], timeout: 1) + } - let predictionsRepo = FakePredictionsRepository( - joinExpectation: joinExpectation, - leaveExpectation: leaveExpectation + func testLeavesChannelWhenBackgrounded() throws { + let joinExpectation = expectation(description: "joins predictions") + let leaveExpectation = expectation(description: "leaves predictions") + + let predictionsRepo = MockPredictionsRepository( + onConnect: { joinExpectation.fulfill() }, + onConnectV2: {}, + onDisconnect: { leaveExpectation.fulfill() } ) let sut = NearbyTransitView( togglePinnedUsecase: TogglePinnedRouteUsecase(repository: pinnedRoutesRepository), @@ -644,32 +757,9 @@ final class NearbyTransitViewTests: XCTestCase { let joinExpectation = expectation(description: "joins predictions") let leaveExpectation = expectation(description: "leaves predictions") - class FakePredictionsRepository: IPredictionsRepository { - let joinExpectation: XCTestExpectation - let leaveExpectation: XCTestExpectation - - init(joinExpectation: XCTestExpectation, leaveExpectation: XCTestExpectation) { - self.joinExpectation = joinExpectation - self.leaveExpectation = leaveExpectation - } + let predictionsRepo = MockPredictionsRepository(onConnect: { joinExpectation.fulfill() }, + onConnectV2: {}, onDisconnect: { leaveExpectation.fulfill() }) - func connect( - stopIds _: [String], - onReceive _: @escaping (Outcome) - -> Void - ) { - joinExpectation.fulfill() - } - - func disconnect() { - leaveExpectation.fulfill() - } - } - - let predictionsRepo = FakePredictionsRepository( - joinExpectation: joinExpectation, - leaveExpectation: leaveExpectation - ) let sut = NearbyTransitView( togglePinnedUsecase: TogglePinnedRouteUsecase(repository: pinnedRoutesRepository), pinnedRouteRepository: pinnedRoutesRepository, @@ -696,31 +786,10 @@ final class NearbyTransitViewTests: XCTestCase { let leaveExpectation = expectation(description: "leaves predictions") - class FakePredictionsRepository: IPredictionsRepository { - let joinExpectation: XCTestExpectation - let leaveExpectation: XCTestExpectation - - init(joinExpectation: XCTestExpectation, leaveExpectation: XCTestExpectation) { - self.joinExpectation = joinExpectation - self.leaveExpectation = leaveExpectation - } - - func connect( - stopIds _: [String], - onReceive _: @escaping (Outcome) - -> Void - ) { - joinExpectation.fulfill() - } - - func disconnect() { - leaveExpectation.fulfill() - } - } - - let predictionsRepo = FakePredictionsRepository( - joinExpectation: joinExpectation, - leaveExpectation: leaveExpectation + let predictionsRepo = MockPredictionsRepository( + onConnect: { joinExpectation.fulfill() }, + onConnectV2: {}, + onDisconnect: { leaveExpectation.fulfill() } ) let sut = NearbyTransitView( togglePinnedUsecase: TogglePinnedRouteUsecase(repository: pinnedRoutesRepository), @@ -789,27 +858,14 @@ final class NearbyTransitViewTests: XCTestCase { loadedLocation: CLLocationCoordinate2D(latitude: 12.34, longitude: -56.78), nearbyByRouteAndStop: NearbyStaticData(data: []) ) - class FakePredictionsRepository: IPredictionsRepository { - let callback: (() -> Void)? - init(callback: @escaping (() -> Void)) { - self.callback = callback - } - - func connect( - stopIds _: [String], - onReceive: @escaping (Outcome) - -> Void - ) { - callback?() - onReceive(Outcome(data: nil, error: SocketError.unknown.toKotlinEnum())) - } - - func disconnect() { /* no-op */ } - } - let predictionsRepo = FakePredictionsRepository { - predictionsErroredPublisher.send(true) - } + let predictionsRepo = MockPredictionsRepository(onConnect: { predictionsErroredPublisher.send(true) }, + onConnectV2: {}, + onDisconnect: {}, + connectOutcome: + Outcome(data: nil, + error: SocketError.unknown.toKotlinEnum()), + connectV2Outcome: nil) let sut = NearbyTransitView( togglePinnedUsecase: TogglePinnedRouteUsecase(repository: pinnedRoutesRepository), pinnedRouteRepository: pinnedRoutesRepository, diff --git a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/endToEnd/EndToEndRepositories.kt b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/endToEnd/EndToEndRepositories.kt index 31c6dac61..500916f59 100644 --- a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/endToEnd/EndToEndRepositories.kt +++ b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/endToEnd/EndToEndRepositories.kt @@ -11,6 +11,8 @@ import com.mbta.tid.mbta_app.model.TripShape import com.mbta.tid.mbta_app.model.response.ApiResult import com.mbta.tid.mbta_app.model.response.GlobalResponse import com.mbta.tid.mbta_app.model.response.NearbyResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.model.response.ScheduleResponse import com.mbta.tid.mbta_app.model.response.StopMapResponse @@ -118,6 +120,14 @@ fun endToEndModule(): Module { onReceive(Outcome(PredictionsStreamDataResponse(objects), null)) } + override fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: (Outcome) -> Unit + ) { + onJoin(Outcome(PredictionsByStopJoinResponse(objects), null)) + } + override fun disconnect() {} } } diff --git a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopJoinResponse.kt b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopJoinResponse.kt new file mode 100644 index 000000000..661396aee --- /dev/null +++ b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopJoinResponse.kt @@ -0,0 +1,68 @@ +package com.mbta.tid.mbta_app.model.response + +import com.mbta.tid.mbta_app.model.ObjectCollectionBuilder +import com.mbta.tid.mbta_app.model.Prediction +import com.mbta.tid.mbta_app.model.Trip +import com.mbta.tid.mbta_app.model.Vehicle +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class PredictionsByStopJoinResponse( + @SerialName("predictions_by_stop") val predictionsByStop: Map>, + val trips: Map, + val vehicles: Map +) { + + constructor( + objects: ObjectCollectionBuilder + ) : this( + objects.predictions.values + .groupBy { it.stopId } + .mapValues { predictions -> predictions.value.associateBy { it.id } }, + objects.trips, + objects.vehicles + ) + + /** + * Merge the latest predictions for a single stop into the predictions for all stops. Removes + * vehicles & trips that are no longer referenced in any predictions + */ + fun mergePredictions( + updatedPredictions: PredictionsByStopMessageResponse + ): PredictionsByStopJoinResponse { + + val updatedPredictionsByStop: Map> = + predictionsByStop.plus(Pair(updatedPredictions.stopId, updatedPredictions.predictions)) + + val usedTrips = mutableSetOf() + val usedVehicles = mutableSetOf() + val predictions = updatedPredictionsByStop.flatMap { it.value.values } + predictions.forEach { + usedTrips.add(it.tripId) + if (it.vehicleId != null) { + usedVehicles.add(it.vehicleId) + } + } + + val updatedTrips = trips.plus(updatedPredictions.trips).filterKeys { it in usedTrips } + val updatedVehicles = + vehicles.plus(updatedPredictions.vehicles).filterKeys { it in usedVehicles } + + return PredictionsByStopJoinResponse( + predictionsByStop = updatedPredictionsByStop, + trips = updatedTrips, + vehicles = updatedVehicles + ) + } + + /** Flattens the `predictionsByStop` field into a single map of predictions by id */ + fun toPredictionsStreamDataResponse(): PredictionsStreamDataResponse { + val predictionsById = predictionsByStop.flatMap { it.value.values }.associateBy { it.id } + return PredictionsStreamDataResponse( + predictions = predictionsById, + trips = trips, + vehicles = vehicles + ) + } +} diff --git a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopMessageResponse.kt b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopMessageResponse.kt new file mode 100644 index 000000000..9ad86ffdd --- /dev/null +++ b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopMessageResponse.kt @@ -0,0 +1,15 @@ +package com.mbta.tid.mbta_app.model.response + +import com.mbta.tid.mbta_app.model.Prediction +import com.mbta.tid.mbta_app.model.Trip +import com.mbta.tid.mbta_app.model.Vehicle +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class PredictionsByStopMessageResponse( + @SerialName("stop_id") val stopId: String, + val predictions: Map, + val trips: Map, + val vehicles: Map +) {} diff --git a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/phoenix/PredictionsForStopsChannel.kt b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/phoenix/PredictionsForStopsChannel.kt index d5e3455c1..b84e8c28c 100644 --- a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/phoenix/PredictionsForStopsChannel.kt +++ b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/phoenix/PredictionsForStopsChannel.kt @@ -1,6 +1,8 @@ package com.mbta.tid.mbta_app.phoenix import com.mbta.tid.mbta_app.json +import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse class PredictionsForStopsChannel { @@ -9,6 +11,8 @@ class PredictionsForStopsChannel { val newDataEvent = "stream_data" + fun topicV2(stopIds: List) = "predictions:stops:v2:${stopIds.joinToString(",")}" + fun joinPayload(stopIds: List): Map { return mapOf("stop_ids" to stopIds) } @@ -17,5 +21,15 @@ class PredictionsForStopsChannel { fun parseMessage(payload: String): PredictionsStreamDataResponse { return json.decodeFromString(payload) } + + @Throws(IllegalArgumentException::class) + fun parseV2JoinMessage(payload: String): PredictionsByStopJoinResponse { + return json.decodeFromString(payload) + } + + @Throws(IllegalArgumentException::class) + fun parseV2Message(payload: String): PredictionsByStopMessageResponse { + return json.decodeFromString(payload) + } } } diff --git a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepository.kt b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepository.kt index 5d6ce9b72..bfcc9a60e 100644 --- a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepository.kt +++ b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepository.kt @@ -2,6 +2,8 @@ package com.mbta.tid.mbta_app.repositories import com.mbta.tid.mbta_app.model.Outcome import com.mbta.tid.mbta_app.model.SocketError +import com.mbta.tid.mbta_app.model.response.PredictionsByStopJoinResponse +import com.mbta.tid.mbta_app.model.response.PredictionsByStopMessageResponse import com.mbta.tid.mbta_app.model.response.PredictionsStreamDataResponse import com.mbta.tid.mbta_app.network.PhoenixChannel import com.mbta.tid.mbta_app.network.PhoenixMessage @@ -16,6 +18,12 @@ interface IPredictionsRepository { onReceive: (Outcome) -> Unit ) + fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: (Outcome) -> Unit + ) + fun disconnect() } @@ -46,6 +54,29 @@ class PredictionsRepository(private val socket: PhoenixSocket) : ?.receive(PhoenixPushStatus.Error) { onReceive(Outcome(null, SocketError.Connection)) } } + override fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: (Outcome) -> Unit, + ) { + disconnect() + channel = socket.getChannel(PredictionsForStopsChannel.topicV2(stopIds), mapOf()) + + channel?.onEvent(PredictionsForStopsChannel.newDataEvent) { message -> + handleV2Message(message, onMessage) + } + channel?.onFailure { onMessage(Outcome(null, SocketError.Unknown)) } + + channel?.onDetach { message -> println("leaving channel ${message.subject}") } + channel + ?.attach() + ?.receive(PhoenixPushStatus.Ok) { message -> + println("joined channel ${message.subject}") + handleV2JoinMessage(message, onJoin) + } + ?.receive(PhoenixPushStatus.Error) { onJoin(Outcome(null, SocketError.Connection)) } + } + override fun disconnect() { channel?.detach() channel = null @@ -71,27 +102,110 @@ class PredictionsRepository(private val socket: PhoenixSocket) : println("No jsonPayload found for message ${message.body}") } } + + private fun handleV2JoinMessage( + message: PhoenixMessage, + onJoin: (Outcome) -> Unit + ) { + val rawPayload: String? = message.jsonBody + + if (rawPayload != null) { + val newPredictionsByStop = + try { + PredictionsForStopsChannel.parseV2JoinMessage(rawPayload) + } catch (e: IllegalArgumentException) { + print("ERROR $e") + onJoin(Outcome(null, SocketError.Unknown)) + return + } + println( + "Received ${newPredictionsByStop.predictionsByStop.values.flatMap { it.values}.size} predictions" + ) + onJoin(Outcome(newPredictionsByStop, null)) + } else { + println("No jsonPayload found for message ${message.body}") + } + } + + /** + * Parse the phoenix message & pass to the onMessage callback + * + * @param message: the message to parse, expected as a PredictionsByStopMessageResponse + * @param onMessage: the callback ot invoke on the parsed message + */ + internal fun handleV2Message( + message: PhoenixMessage, + onMessage: (Outcome) -> Unit + ) { + val rawPayload: String? = message.jsonBody + + if (rawPayload != null) { + val newPredictionsForStop = + try { + PredictionsForStopsChannel.parseV2Message(rawPayload) + } catch (e: IllegalArgumentException) { + onMessage(Outcome(null, SocketError.Unknown)) + return + } + println( + "Received ${newPredictionsForStop.predictions.size} predictions for stop ${newPredictionsForStop.stopId}" + ) + onMessage(Outcome(newPredictionsForStop, null)) + } else { + println("No jsonPayload found for message ${message.body}") + } + } } class MockPredictionsRepository( - private val instantReceive: Outcome? + val onConnect: () -> Unit = {}, + val onConnectV2: () -> Unit = {}, + val onDisconnect: () -> Unit = {}, + private val connectOutcome: Outcome? = null, + private val connectV2Outcome: Outcome? = null ) : IPredictionsRepository { - constructor() : this(instantReceive = null) + + constructor() : + this(onConnect = {}, onConnectV2 = {}, connectOutcome = null, connectV2Outcome = null) constructor( response: PredictionsStreamDataResponse? - ) : this(instantReceive = Outcome(response, null)) + ) : this(connectOutcome = Outcome(response, null)) + + constructor( + onConnect: () -> Unit = {}, + onConnectV2: () -> Unit = {}, + onDisconnect: () -> Unit = {}, + ) : this( + onConnect = onConnect, + onConnectV2 = onConnectV2, + onDisconnect = onDisconnect, + connectOutcome = null, + connectV2Outcome = null + ) override fun connect( stopIds: List, onReceive: (Outcome) -> Unit ) { - if (instantReceive != null) { - onReceive(instantReceive) + onConnect() + if (connectOutcome != null) { + onReceive(connectOutcome) + } + } + + override fun connectV2( + stopIds: List, + onJoin: (Outcome) -> Unit, + onMessage: (Outcome) -> Unit + ) { + onConnectV2() + if (connectV2Outcome != null) { + onJoin(connectV2Outcome) } } override fun disconnect() { - /* no-op */ + onDisconnect() } } diff --git a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/SettingsRepository.kt b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/SettingsRepository.kt index 63e415bcb..9b4972203 100644 --- a/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/SettingsRepository.kt +++ b/shared/src/commonMain/kotlin/com/mbta/tid/mbta_app/repositories/SettingsRepository.kt @@ -38,6 +38,7 @@ enum class Settings(val dataStoreKey: Preferences.Key) { Map(booleanPreferencesKey("map_debug")), Search(booleanPreferencesKey("search_featureFlag")), SearchRouteResults(booleanPreferencesKey("searchRouteResults_featureFlag")), + PredictionsV2Channel(booleanPreferencesKey("predictions_v2Channel")) } data class Setting(val key: Settings, var isOn: Boolean) diff --git a/shared/src/commonTest/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopJoinResponseTest.kt b/shared/src/commonTest/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopJoinResponseTest.kt new file mode 100644 index 000000000..8175b9bcb --- /dev/null +++ b/shared/src/commonTest/kotlin/com/mbta/tid/mbta_app/model/response/PredictionsByStopJoinResponseTest.kt @@ -0,0 +1,158 @@ +package com.mbta.tid.mbta_app.model.response + +import com.mbta.tid.mbta_app.model.ObjectCollectionBuilder.Single.prediction +import com.mbta.tid.mbta_app.model.ObjectCollectionBuilder.Single.trip +import com.mbta.tid.mbta_app.model.ObjectCollectionBuilder.Single.vehicle +import com.mbta.tid.mbta_app.model.Prediction +import com.mbta.tid.mbta_app.model.Vehicle +import kotlin.test.Test +import kotlin.test.assertEquals + +class PredictionsByStopJoinResponseTest { + @Test + fun `mergePredictions replaces predictions for existing stop`() { + val trip = trip() + val vehicle = vehicle { currentStatus = Vehicle.CurrentStatus.StoppedAt } + + val p1stop1 = prediction { + stopId = "1" + tripId = trip.id + vehicleId = vehicle.id + } + + val p1stop1updated = prediction { + id = p1stop1.id + stopId = "1" + tripId = trip.id + vehicleId = vehicle.id + scheduleRelationship = Prediction.ScheduleRelationship.Cancelled + } + val p2stop1 = prediction { + stopId = "1" + tripId = trip.id + vehicleId = vehicle.id + } + + val p1stop2 = prediction { + stopId = "2" + tripId = trip.id + vehicleId = vehicle.id + } + val p2stop2 = prediction { + stopId = "2" + tripId = trip.id + vehicleId = vehicle.id + } + + val existingPredictions = + PredictionsByStopJoinResponse( + predictionsByStop = + mapOf( + "1" to mapOf(p1stop1.id to p1stop1), + "2" to mapOf(p1stop2.id to p1stop2, p2stop2.id to p2stop2) + ), + trips = mapOf(trip.id to trip), + vehicles = mapOf(vehicle.id to vehicle) + ) + + val stop1NewPredictions = + PredictionsByStopMessageResponse( + stopId = "1", + predictions = mapOf(p1stop1updated.id to p1stop1updated, p2stop1.id to p2stop1), + trips = mapOf(), + vehicles = mapOf() + ) + + val result = existingPredictions.mergePredictions(stop1NewPredictions) + + assertEquals( + mapOf(p1stop1updated.id to p1stop1updated, p2stop1.id to p2stop1), + result.predictionsByStop["1"] + ) + assertEquals( + mapOf(p1stop2.id to p1stop2, p2stop2.id to p2stop2), + result.predictionsByStop["2"] + ) + assertEquals(mapOf(trip.id to trip), result.trips) + assertEquals(mapOf(vehicle.id to vehicle), result.vehicles) + } + + @Test + fun `mergePredictions removes trips and vehicles that are no longer referenced`() { + val trip1 = trip() + val vehicle1 = vehicle { currentStatus = Vehicle.CurrentStatus.StoppedAt } + + val trip2 = trip() + val vehicle2 = vehicle { currentStatus = Vehicle.CurrentStatus.StoppedAt } + + val p1stop1 = prediction { + stopId = "1" + tripId = trip1.id + vehicleId = vehicle1.id + } + + val p1stop2 = prediction { + stopId = "2" + tripId = trip2.id + vehicleId = vehicle2.id + } + + val existingPredictions = + PredictionsByStopJoinResponse( + predictionsByStop = + mapOf("1" to mapOf(p1stop1.id to p1stop1), "2" to mapOf(p1stop2.id to p1stop2)), + trips = mapOf(trip1.id to trip1, trip2.id to trip2), + vehicles = mapOf(vehicle1.id to vehicle1, vehicle2.id to vehicle2) + ) + + val stop1NewPredictions = + PredictionsByStopMessageResponse( + stopId = "1", + predictions = mapOf(), + trips = mapOf(), + vehicles = mapOf() + ) + + val result = existingPredictions.mergePredictions(stop1NewPredictions) + + assertEquals(mapOf(), result.predictionsByStop["1"]) + assertEquals(mapOf(p1stop2.id to p1stop2), result.predictionsByStop["2"]) + assertEquals(mapOf(trip2.id to trip2), result.trips) + assertEquals(mapOf(vehicle2.id to vehicle2), result.vehicles) + } + + @Test + fun `toPredictionsStreamDataResponse flattens predictionsByStop preserves trips + vehicles`() { + val trip1 = trip() + val vehicle1 = vehicle { currentStatus = Vehicle.CurrentStatus.StoppedAt } + + val trip2 = trip() + val vehicle2 = vehicle { currentStatus = Vehicle.CurrentStatus.StoppedAt } + + val p1stop1 = prediction { + stopId = "1" + tripId = trip1.id + vehicleId = vehicle1.id + } + + val p1stop2 = prediction { + stopId = "2" + tripId = trip2.id + vehicleId = vehicle2.id + } + + val data = + PredictionsByStopJoinResponse( + predictionsByStop = + mapOf("1" to mapOf(p1stop1.id to p1stop1), "2" to mapOf(p1stop2.id to p1stop2)), + trips = mapOf(trip1.id to trip1, trip2.id to trip2), + vehicles = mapOf(vehicle1.id to vehicle1, vehicle2.id to vehicle2) + ) + + val response = data.toPredictionsStreamDataResponse() + + assertEquals(mapOf(p1stop1.id to p1stop1, p1stop2.id to p1stop2), response.predictions) + assertEquals(mapOf(trip1.id to trip1, trip2.id to trip2), response.trips) + assertEquals(mapOf(vehicle1.id to vehicle1, vehicle2.id to vehicle2), response.vehicles) + } +} diff --git a/shared/src/commonTest/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepositoryTests.kt b/shared/src/commonTest/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepositoryTests.kt index 72fec34b6..532b6fa26 100644 --- a/shared/src/commonTest/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepositoryTests.kt +++ b/shared/src/commonTest/kotlin/com/mbta/tid/mbta_app/repositories/PredictionsRepositoryTests.kt @@ -14,6 +14,7 @@ import dev.mokkery.answering.returns import dev.mokkery.every import dev.mokkery.matcher.any import dev.mokkery.mock +import dev.mokkery.verify import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -124,4 +125,300 @@ class PredictionsRepositoryTests : KoinTest { } ) } + + @Test + fun testV2ChannelSetOnRun() { + val socket = mock(MockMode.autofill) + val channel = mock(MockMode.autofill) + val push = mock(MockMode.autofill) + val predictionsRepo = PredictionsRepository(socket) + every { channel.attach() } returns push + + every { push.receive(any(), any()) } returns push + every { socket.getChannel("predictions:stops:v2:1,2", any()) } returns channel + assertNull(predictionsRepo.channel) + predictionsRepo.connectV2( + stopIds = listOf("1", "2"), + onJoin = { /* no-op */}, + onMessage = { /* no-op */} + ) + + assertNotNull(predictionsRepo.channel) + } + + @Test + fun testV2ChannelJoinTwiceLeavesOldChannel() { + val socket = mock(MockMode.autofill) + val channel = mock(MockMode.autofill) + val push = mock(MockMode.autofill) + val predictionsRepo = PredictionsRepository(socket) + every { channel.attach() } returns push + + every { push.receive(any(), any()) } returns push + every { socket.getChannel(any(), any()) } returns channel + assertNull(predictionsRepo.channel) + predictionsRepo.connectV2( + stopIds = listOf("1", "2"), + onJoin = { /* no-op */}, + onMessage = { /* no-op */} + ) + + assertNotNull(predictionsRepo.channel) + + predictionsRepo.connectV2( + stopIds = listOf("3", "4"), + onJoin = { /* no-op */}, + onMessage = { /* no-op */} + ) + + verify { channel.detach() } + } + + @Test + fun testV2ChannelClearedOnLeave() { + val socket = mock(MockMode.autofill) + val channel = mock(MockMode.autofill) + val push = mock(MockMode.autofill) + val predictionsRepo = PredictionsRepository(socket) + every { channel.attach() } returns push + + every { push.receive(any(), any()) } returns push + every { socket.getChannel(any(), any()) } returns channel + assertNull(predictionsRepo.channel) + predictionsRepo.connectV2( + stopIds = listOf("1", "2"), + onJoin = { /* no-op */}, + onMessage = { /* no-op */} + ) + + assertNotNull(predictionsRepo.channel) + + predictionsRepo.disconnect() + + verify { channel.detach() } + } + + @Test + fun testV2SetsPredictionsOnJoinResponse() { + class MockPush : PhoenixPush { + override fun receive( + status: PhoenixPushStatus, + callback: (PhoenixMessage) -> Unit + ): PhoenixPush { + if (status == PhoenixPushStatus.Ok) { + callback( + MockMessage( + jsonBody = + """ + {"predictions_by_stop": + {"12345": + { + "p_1": { + "id": "p_1", + "arrival_time": null, + "departure_time": null, + "direction_id": 0, + "revenue": false, + "schedule_relationship": "scheduled", + "status": null, + "route_id": "66", + "stop_id": "12345", + "trip_id": "t_1", + "vehicle_id": "v_1", + "stop_sequence": 38 + } + } + }, + "trips": { + "t_1": { + "id": "t_1", + "direction_id": 0, + "headsign": "Nubian", + "route_id": "66", + "route_pattern_id": "66-0-0", + "shape_id": "shape_id", + "stop_ids": [] + } + }, + "vehicles": { + "v_1": { + "id": "v_1", + "bearing": 351, + "current_status": "in_transit_to", + "current_stop_sequence": 17, + "direction_id": 0, + "route_id": "66", + "trip_id": "t_1", + "stop_id": "12345", + "latitude": 42.34114183, + "longitude": -71.121119039, + "updated_at": "2024-09-23T11:30:26-04:00" + + } + } + } + """ + .trimIndent() + ) + ) + } + return this + } + } + val socket = mock(MockMode.autofill) + val predictionsRepo = PredictionsRepository(socket) + val channel = mock(MockMode.autofill) + val push = MockPush() + every { socket.getChannel(any(), any()) } returns channel + every { channel.attach() } returns push + + predictionsRepo.connectV2( + stopIds = listOf("1"), + onJoin = { outcome -> + outcome.data?.let { + assertEquals(1, it.predictionsByStop.size) + assertEquals("p_1", it.predictionsByStop["12345"]?.get("p_1")?.id) + + assertEquals(1, it.trips.size) + assertEquals("t_1", it.trips["t_1"]?.id) + + assertEquals(1, it.vehicles.size) + assertEquals("v_1", it.vehicles["v_1"]?.id) + } + outcome.error?.let { fail() } + }, + onMessage = { /* no-op */} + ) + } + + @Test + fun testV2HandleV2Message() { + + val message = + MockMessage( + jsonBody = + """ + { + "stop_id": "12345", + "predictions": + { + "p_1": { + "id": "p_1", + "arrival_time": null, + "departure_time": null, + "direction_id": 0, + "revenue": false, + "schedule_relationship": "scheduled", + "status": null, + "route_id": "66", + "stop_id": "12345", + "trip_id": "t_1", + "vehicle_id": "v_1", + "stop_sequence": 38 + } + }, + "trips": { + "t_1": { + "id": "t_1", + "direction_id": 0, + "headsign": "Nubian", + "route_id": "66", + "route_pattern_id": "66-0-0", + "shape_id": "shape_id", + "stop_ids": [] + } + }, + "vehicles": { + "v_1": { + "id": "v_1", + "bearing": 351, + "current_status": "in_transit_to", + "current_stop_sequence": 17, + "direction_id": 0, + "route_id": "66", + "trip_id": "t_1", + "stop_id": "12345", + "latitude": 42.34114183, + "longitude": -71.121119039, + "updated_at": "2024-09-23T11:30:26-04:00" + + } + } + } + """ + .trimIndent() + ) + val socket = mock(MockMode.autofill) + val predictionsRepo = PredictionsRepository(socket) + predictionsRepo.handleV2Message( + message, + onMessage = { outcome -> + outcome.data?.let { + assertEquals("12345", it.stopId) + assertEquals("p_1", it.predictions["p_1"]?.id) + + assertEquals(1, it.trips.size) + assertEquals("t_1", it.trips["t_1"]?.id) + + assertEquals(1, it.vehicles.size) + assertEquals("v_1", it.vehicles["v_1"]?.id) + } + outcome.error?.let { fail() } + } + ) + } + + @Test + fun testV2SetsErrorWhenReceivedOnJoin() { + val socket = mock(MockMode.autofill) + val predictionsRepo = PredictionsRepository(socket) + val push = mock(MockMode.autofill) + every { push.receive(any(), any()) } returns push + class MockChannel : PhoenixChannel { + override fun onEvent(event: String, callback: (PhoenixMessage) -> Unit) { + /* no-op */ + } + + override fun onFailure(callback: (message: PhoenixMessage) -> Unit) { + callback(MockMessage()) + } + + override fun onDetach(callback: (PhoenixMessage) -> Unit) { + /* no-op */ + } + + override fun attach(): PhoenixPush { + return push + } + + override fun detach(): PhoenixPush { + return push + } + } + every { socket.getChannel(any(), any()) } returns MockChannel() + predictionsRepo.connectV2( + stopIds = listOf("1"), + onJoin = { outcome -> + assertNotNull(outcome.error) + assertEquals(outcome.error, SocketError.Unknown) + }, + onMessage = { /* no-op */} + ) + } + + @Test + fun testV2SetsErrorWhenReceivedOnMessage() { + + val message = MockMessage(jsonBody = "BAD_DATA") + val socket = mock(MockMode.autofill) + val predictionsRepo = PredictionsRepository(socket) + predictionsRepo.handleV2Message( + message, + onMessage = { outcome -> + outcome.data?.let { fail() } + + outcome.error?.let { error -> assertEquals(SocketError.Unknown, error) } + } + ) + } }