diff --git a/moqt-core/src/modules/messages/data_streams/object_stream_subgroup.rs b/moqt-core/src/modules/messages/data_streams/object_stream_subgroup.rs index 91ecc93..f6cd2b2 100644 --- a/moqt-core/src/modules/messages/data_streams/object_stream_subgroup.rs +++ b/moqt-core/src/modules/messages/data_streams/object_stream_subgroup.rs @@ -47,6 +47,10 @@ impl ObjectStreamSubgroup { pub fn object_id(&self) -> u64 { self.object_id } + + pub fn object_status(&self) -> Option { + self.object_status + } } impl DataStreams for ObjectStreamSubgroup { diff --git a/moqt-core/src/modules/messages/data_streams/object_stream_track.rs b/moqt-core/src/modules/messages/data_streams/object_stream_track.rs index a8de220..7b80673 100644 --- a/moqt-core/src/modules/messages/data_streams/object_stream_track.rs +++ b/moqt-core/src/modules/messages/data_streams/object_stream_track.rs @@ -54,6 +54,10 @@ impl ObjectStreamTrack { pub fn object_id(&self) -> u64 { self.object_id } + + pub fn object_status(&self) -> Option { + self.object_status + } } impl DataStreams for ObjectStreamTrack { diff --git a/moqt-core/src/modules/pubsub_relation_manager_repository.rs b/moqt-core/src/modules/pubsub_relation_manager_repository.rs index a1ede16..edae97a 100644 --- a/moqt-core/src/modules/pubsub_relation_manager_repository.rs +++ b/moqt-core/src/modules/pubsub_relation_manager_repository.rs @@ -53,6 +53,11 @@ pub trait PubSubRelationManagerRepository: Send + Sync { track_namespace: Vec, track_name: String, ) -> Result>; + async fn get_upstream_subscription_by_ids( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result>; async fn get_downstream_subscription_by_ids( &self, downstream_session_id: usize, diff --git a/moqt-server/src/modules/message_handlers/object_stream.rs b/moqt-server/src/modules/message_handlers/object_stream.rs index ffd62c5..d2b19eb 100644 --- a/moqt-server/src/modules/message_handlers/object_stream.rs +++ b/moqt-server/src/modules/message_handlers/object_stream.rs @@ -17,7 +17,7 @@ use std::io::Cursor; #[derive(Debug, PartialEq)] pub enum ObjectStreamProcessResult { - Success, + Success(CacheObject), IncompleteMessage, Failure(TerminationErrorCode, String), } @@ -61,16 +61,18 @@ pub async fn object_stream_handler( Ok(object) => { read_buf.advance(read_cur.position() as usize); - let cache_object = CacheObject::Track(object); + let received_object = CacheObject::Track(object); object_cache_storage - .set_object(client.id(), subscribe_id, cache_object, duration) + .set_object(client.id(), subscribe_id, received_object.clone(), duration) .await .unwrap(); + + ObjectStreamProcessResult::Success(received_object) } Err(err) => { tracing::warn!("{:#?}", err); read_cur.set_position(0); - return ObjectStreamProcessResult::IncompleteMessage; + ObjectStreamProcessResult::IncompleteMessage } } } @@ -80,26 +82,24 @@ pub async fn object_stream_handler( Ok(object) => { read_buf.advance(read_cur.position() as usize); - let cache_object = CacheObject::Subgroup(object); + let received_object = CacheObject::Subgroup(object); object_cache_storage - .set_object(client.id(), subscribe_id, cache_object, duration) + .set_object(client.id(), subscribe_id, received_object.clone(), duration) .await .unwrap(); + + ObjectStreamProcessResult::Success(received_object) } Err(err) => { tracing::warn!("{:#?}", err); read_cur.set_position(0); - return ObjectStreamProcessResult::IncompleteMessage; + ObjectStreamProcessResult::IncompleteMessage } } } - unknown => { - return ObjectStreamProcessResult::Failure( - TerminationErrorCode::ProtocolViolation, - format!("Unknown message type: {:?}", unknown), - ); - } - }; - - ObjectStreamProcessResult::Success + unknown => ObjectStreamProcessResult::Failure( + TerminationErrorCode::ProtocolViolation, + format!("Unknown message type: {:?}", unknown), + ), + } } diff --git a/moqt-server/src/modules/message_handlers/stream_header.rs b/moqt-server/src/modules/message_handlers/stream_header.rs index 0283175..0b66dec 100644 --- a/moqt-server/src/modules/message_handlers/stream_header.rs +++ b/moqt-server/src/modules/message_handlers/stream_header.rs @@ -9,7 +9,7 @@ use crate::{ stream_track_subgroup::process_stream_header_subgroup, }, moqt_client::{MOQTClient, MOQTClientStatus}, - object_cache_storage::ObjectCacheStorageWrapper, + object_cache_storage::{CacheHeader, ObjectCacheStorageWrapper}, }, }; use anyhow::{bail, Result}; @@ -23,7 +23,7 @@ use std::io::Cursor; #[derive(Debug, PartialEq)] pub enum StreamHeaderProcessResult { - Success((u64, DataStreamType)), + Success(CacheHeader), IncompleteMessage, Failure(TerminationErrorCode, String), } @@ -92,7 +92,7 @@ pub async fn stream_header_handler( ); } - let subscribe_id = match header_type { + match header_type { DataStreamType::StreamHeaderTrack => { match process_stream_header_track( &mut read_cur, @@ -102,16 +102,17 @@ pub async fn stream_header_handler( ) .await { - Ok(subscribe_id) => { + Ok(received_header) => { read_buf.advance(read_cur.position() as usize); - subscribe_id + + StreamHeaderProcessResult::Success(received_header) } Err(err) => { read_buf.advance(read_cur.position() as usize); - return StreamHeaderProcessResult::Failure( + StreamHeaderProcessResult::Failure( TerminationErrorCode::InternalError, err.to_string(), - ); + ) } } } @@ -124,26 +125,23 @@ pub async fn stream_header_handler( ) .await { - Ok(subscribe_id) => { + Ok(received_header) => { read_buf.advance(read_cur.position() as usize); - subscribe_id + + StreamHeaderProcessResult::Success(received_header) } Err(err) => { read_buf.advance(read_cur.position() as usize); - return StreamHeaderProcessResult::Failure( + StreamHeaderProcessResult::Failure( TerminationErrorCode::InternalError, err.to_string(), - ); + ) } } } - unknown => { - return StreamHeaderProcessResult::Failure( - TerminationErrorCode::ProtocolViolation, - format!("Unknown message type: {:?}", unknown), - ); - } - }; - - StreamHeaderProcessResult::Success((subscribe_id, header_type)) + unknown => StreamHeaderProcessResult::Failure( + TerminationErrorCode::ProtocolViolation, + format!("Unknown message type: {:?}", unknown), + ), + } } diff --git a/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_subgroup_header_handler.rs b/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_subgroup_header_handler.rs index 11f291d..07cb76d 100644 --- a/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_subgroup_header_handler.rs +++ b/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_subgroup_header_handler.rs @@ -16,7 +16,7 @@ pub(crate) async fn stream_header_subgroup_handler( pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, object_cache_storage: &mut ObjectCacheStorageWrapper, client: &MOQTClient, -) -> Result { +) -> Result { tracing::trace!("stream_header_subgroup_handler start."); tracing::debug!( @@ -37,8 +37,12 @@ pub(crate) async fn stream_header_subgroup_handler( let cache_header = CacheHeader::Subgroup(stream_header_subgroup_message); object_cache_storage - .set_subscription(upstream_session_id, upstream_subscribe_id, cache_header) + .set_subscription( + upstream_session_id, + upstream_subscribe_id, + cache_header.clone(), + ) .await?; - Ok(upstream_subscribe_id) + Ok(cache_header) } diff --git a/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_track_header_handler.rs b/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_track_header_handler.rs index 617e8ad..4cf56a3 100644 --- a/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_track_header_handler.rs +++ b/moqt-server/src/modules/message_handlers/stream_header/handlers/stream_track_header_handler.rs @@ -16,7 +16,7 @@ pub(crate) async fn stream_header_track_handler( pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, object_cache_storage: &mut ObjectCacheStorageWrapper, client: &MOQTClient, -) -> Result { +) -> Result { tracing::trace!("stream_header_track_handler start."); tracing::debug!( @@ -37,8 +37,12 @@ pub(crate) async fn stream_header_track_handler( let cache_header = CacheHeader::Track(stream_header_track_message); object_cache_storage - .set_subscription(upstream_session_id, upstream_subscribe_id, cache_header) + .set_subscription( + upstream_session_id, + upstream_subscribe_id, + cache_header.clone(), + ) .await?; - Ok(upstream_subscribe_id) + Ok(cache_header) } diff --git a/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_header.rs b/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_header.rs index 215c0e8..bbdf715 100644 --- a/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_header.rs +++ b/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_header.rs @@ -7,7 +7,8 @@ use moqt_core::{ use crate::modules::{ message_handlers::stream_header::handlers::stream_track_header_handler::stream_header_track_handler, - moqt_client::MOQTClient, object_cache_storage::ObjectCacheStorageWrapper, + moqt_client::MOQTClient, + object_cache_storage::{CacheHeader, ObjectCacheStorageWrapper}, }; pub(crate) async fn process_stream_header_track( @@ -15,7 +16,7 @@ pub(crate) async fn process_stream_header_track( pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, object_cache_storage: &mut ObjectCacheStorageWrapper, client: &MOQTClient, -) -> Result { +) -> Result { let stream_header_track = match StreamHeaderTrack::depacketize(read_cur) { Ok(stream_header_track) => stream_header_track, Err(err) => { diff --git a/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_subgroup.rs b/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_subgroup.rs index b39760c..24ac202 100644 --- a/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_subgroup.rs +++ b/moqt-server/src/modules/message_handlers/stream_header/server_processes/stream_track_subgroup.rs @@ -7,7 +7,8 @@ use moqt_core::{ use crate::modules::{ message_handlers::stream_header::handlers::stream_subgroup_header_handler::stream_header_subgroup_handler, - moqt_client::MOQTClient, object_cache_storage::ObjectCacheStorageWrapper, + moqt_client::MOQTClient, + object_cache_storage::{CacheHeader, ObjectCacheStorageWrapper}, }; pub(crate) async fn process_stream_header_subgroup( @@ -15,7 +16,7 @@ pub(crate) async fn process_stream_header_subgroup( pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, object_cache_storage: &mut ObjectCacheStorageWrapper, client: &MOQTClient, -) -> Result { +) -> Result { let stream_header_subgroup = match StreamHeaderSubgroup::depacketize(read_cur) { Ok(stream_header_subgroup) => stream_header_subgroup, Err(err) => { diff --git a/moqt-server/src/modules/object_cache_storage.rs b/moqt-server/src/modules/object_cache_storage.rs index 7099c47..4af0409 100644 --- a/moqt-server/src/modules/object_cache_storage.rs +++ b/moqt-server/src/modules/object_cache_storage.rs @@ -22,7 +22,7 @@ pub(crate) enum CacheHeader { } #[allow(dead_code)] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub(crate) enum CacheObject { Datagram(ObjectDatagram), Track(ObjectStreamTrack), diff --git a/moqt-server/src/modules/pubsub_relation_manager/commands.rs b/moqt-server/src/modules/pubsub_relation_manager/commands.rs index ed9542d..afceddd 100644 --- a/moqt-server/src/modules/pubsub_relation_manager/commands.rs +++ b/moqt-server/src/modules/pubsub_relation_manager/commands.rs @@ -59,6 +59,11 @@ pub(crate) enum PubSubRelationCommand { track_name: String, resp: oneshot::Sender>>, }, + GetUpstreamSubscriptionBySessionIdAndSubscribeId { + upstream_session_id: usize, + upstream_subscribe_id: u64, + resp: oneshot::Sender>>, + }, GetDownstreamSubscriptionBySessionIdAndSubscribeId { downstream_session_id: usize, downstream_subscribe_id: u64, diff --git a/moqt-server/src/modules/pubsub_relation_manager/manager.rs b/moqt-server/src/modules/pubsub_relation_manager/manager.rs index 520b40c..8d69cfd 100644 --- a/moqt-server/src/modules/pubsub_relation_manager/manager.rs +++ b/moqt-server/src/modules/pubsub_relation_manager/manager.rs @@ -288,6 +288,16 @@ pub(crate) async fn pubsub_relation_manager(rx: &mut mpsc::Receiver { + let consumer = consumers.get(&upstream_session_id).unwrap(); + let result = consumer.get_subscription(upstream_subscribe_id); + + resp.send(result).unwrap(); + } GetDownstreamSubscriptionBySessionIdAndSubscribeId { downstream_session_id, downstream_subscribe_id, diff --git a/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs b/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs index 94d768f..94e790d 100644 --- a/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs +++ b/moqt-server/src/modules/pubsub_relation_manager/wrapper.rs @@ -210,6 +210,26 @@ impl PubSubRelationManagerRepository for PubSubRelationManagerWrapper { Err(err) => bail!(err), } } + async fn get_upstream_subscription_by_ids( + &self, + upstream_session_id: usize, + upstream_subscribe_id: u64, + ) -> Result> { + let (resp_tx, resp_rx) = oneshot::channel::>>(); + let cmd = PubSubRelationCommand::GetUpstreamSubscriptionBySessionIdAndSubscribeId { + upstream_session_id, + upstream_subscribe_id, + resp: resp_tx, + }; + self.tx.send(cmd).await.unwrap(); + + let result = resp_rx.await.unwrap(); + + match result { + Ok(subscription) => Ok(subscription), + Err(err) => bail!(err), + } + } async fn get_downstream_subscription_by_ids( &self, downstream_session_id: usize, @@ -1114,6 +1134,68 @@ mod success { assert_eq!(subscription, Some(expected_subscription)); } + #[tokio::test] + async fn get_upstream_subscription_by_ids() { + let max_subscribe_id = 10; + let upstream_session_id = 1; + let upstream_subscribe_id = 0; + let track_alias = 0; + let track_namespace = Vec::from(["test".to_string(), "test".to_string()]); + let track_name = "track_name".to_string(); + let subscriber_priority = 0; + let group_order = GroupOrder::Ascending; + let filter_type = FilterType::AbsoluteStart; + let start_group = Some(0); + let start_object = Some(0); + let end_group = None; + let end_object = None; + + // Start track management thread + let (track_tx, mut track_rx) = mpsc::channel::(1024); + tokio::spawn(async move { pubsub_relation_manager(&mut track_rx).await }); + + let pubsub_relation_manager = PubSubRelationManagerWrapper::new(track_tx.clone()); + let _ = pubsub_relation_manager + .setup_publisher(max_subscribe_id, upstream_session_id) + .await; + let _ = pubsub_relation_manager + .set_upstream_subscription( + upstream_session_id, + track_namespace.clone(), + track_name.clone(), + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + end_object, + ) + .await; + + let subscription = pubsub_relation_manager + .get_upstream_subscription_by_ids(upstream_session_id, upstream_subscribe_id) + .await + .unwrap(); + + let forwarding_preference = None; + let expected_subscription = Subscription::new( + track_alias, + track_namespace, + track_name, + subscriber_priority, + group_order, + filter_type, + start_group, + start_object, + end_group, + end_object, + forwarding_preference, + ); + + assert_eq!(subscription, Some(expected_subscription)); + } + #[tokio::test] async fn get_downstream_subscription_by_ids() { let max_subscribe_id = 10; diff --git a/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/receiver.rs b/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/receiver.rs index 17d1d52..e605c53 100644 --- a/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/receiver.rs +++ b/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/receiver.rs @@ -7,7 +7,7 @@ use crate::{ stream_header::{stream_header_handler, StreamHeaderProcessResult}, }, moqt_client::MOQTClient, - object_cache_storage::ObjectCacheStorageWrapper, + object_cache_storage::{CacheHeader, CacheObject, ObjectCacheStorageWrapper}, pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, server_processes::senders::Senders, }, @@ -16,7 +16,12 @@ use crate::{ use anyhow::Result; use bytes::BytesMut; use moqt_core::{ - constants::TerminationErrorCode, data_stream_type::DataStreamType, + constants::TerminationErrorCode, + data_stream_type::DataStreamType, + messages::{ + control_messages::subscribe::FilterType, data_streams::object_status::ObjectStatus, + }, + models::subscriptions::Subscription, pubsub_relation_manager_repository::PubSubRelationManagerRepository, }; use std::sync::Arc; @@ -30,6 +35,7 @@ pub(crate) struct UniStreamReceiver { client: Arc>, subscribe_id: Option, stream_header_type: Option, + upstream_subscription: Option, } impl UniStreamReceiver { @@ -46,6 +52,7 @@ impl UniStreamReceiver { client, subscribe_id: None, stream_header_type: None, + upstream_subscription: None, } } @@ -55,12 +62,7 @@ impl UniStreamReceiver { Ok(()) } - pub(crate) async fn terminate(&self, code: TerminationErrorCode, reason: String) -> Result<()> { - self.senders - .close_session_tx() - .send((u8::from(code) as u64, reason.to_string())) - .await?; - + pub(crate) async fn terminate(&self) -> Result<()> { self.senders .buffer_tx() .send(BufferCommand::ReleaseStream { @@ -69,6 +71,8 @@ impl UniStreamReceiver { }) .await?; + tracing::debug!("Terminated UniStreamReceiver"); + Ok(()) } @@ -77,16 +81,28 @@ impl UniStreamReceiver { ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); let mut header_read = false; + let mut is_end = false; + + // If the received object is subgroup, store group id to judge the end of range. + let mut subgroup_group_id: Option = None; loop { + if is_end { + break; + } + self.read_stream_and_add_to_buf().await?; - self.read_buf_to_end_and_store_to_object_cache( - &mut header_read, - &mut object_cache_storage, - ) - .await?; + is_end = self + .read_buf_to_end_and_store_to_object_cache( + &mut header_read, + &mut object_cache_storage, + &mut subgroup_group_id, + ) + .await?; } + + Ok(()) } async fn read_stream_and_add_to_buf(&mut self) -> Result<(), TerminationError> { @@ -114,23 +130,26 @@ impl UniStreamReceiver { &mut self, header_read: &mut bool, object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<(), TerminationError> { + subgroup_group_id: &mut Option, + ) -> Result { loop { if !*header_read { match self .read_header_from_buf_and_store_to_object_cache(object_cache_storage) .await { - Ok((is_succeeded, header_params)) => { + Ok((is_succeeded, received_header)) => { if !is_succeeded { break; } else { *header_read = true; - let (subscribe_id, header_type) = header_params.unwrap(); - - self.subscribe_id = Some(subscribe_id); - self.stream_header_type = Some(header_type); + let received_header = received_header.unwrap(); + self.set_rest_parameters_from_header( + received_header, + subgroup_group_id, + ) + .await?; } } Err(err) => { @@ -146,10 +165,22 @@ impl UniStreamReceiver { .read_object_stream_and_store_to_object_cache(object_cache_storage) .await { - Ok(is_succeeded) => { - if !is_succeeded { + Ok(received_object) => { + if received_object.is_none() { + // return to read consequent data from stream break; } + + let is_end = self + .judge_end_of_receiving( + received_object.as_ref().unwrap(), + subgroup_group_id, + ) + .await?; + + if is_end { + return Ok(true); + } } Err(err) => { let msg = format!("Fail to read object stream from buf: {:?}", err); @@ -160,17 +191,37 @@ impl UniStreamReceiver { } } - Ok(()) + Ok(false) } async fn read_header_from_buf_and_store_to_object_cache( &self, object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result<(bool, Option<(u64, DataStreamType)>), TerminationError> { + ) -> Result<(bool, Option), TerminationError> { let result = self.try_to_read_buf_as_header(object_cache_storage).await; match result { - StreamHeaderProcessResult::Success((subscribe_id, header_type)) => { + StreamHeaderProcessResult::Success(received_header) => { + let (subscribe_id, header_type) = match &received_header { + CacheHeader::Track(header) => { + let header_type = DataStreamType::StreamHeaderTrack; + let subscribe_id = header.subscribe_id(); + + (subscribe_id, header_type) + } + CacheHeader::Subgroup(header) => { + let header_type = DataStreamType::StreamHeaderSubgroup; + let subscribe_id = header.subscribe_id(); + + (subscribe_id, header_type) + } + _ => { + let msg = "received header not matched".to_string(); + let code = TerminationErrorCode::InternalError; + return Err((code, msg)); + } + }; + match self .open_downstream_uni_stream(subscribe_id, &header_type) .await @@ -183,7 +234,7 @@ impl UniStreamReceiver { return Err((code, msg)); } }; - Ok((true, Some((subscribe_id, header_type)))) + Ok((true, Some(received_header))) } StreamHeaderProcessResult::IncompleteMessage => Ok((false, None)), StreamHeaderProcessResult::Failure(code, reason) => { @@ -242,17 +293,81 @@ impl UniStreamReceiver { Ok(()) } + async fn set_rest_parameters_from_header( + &mut self, + received_header: CacheHeader, + subgroup_group_id: &mut Option, + ) -> Result<(), TerminationError> { + // Set subscribe_id and stream_header_type from received header + let (subscribe_id, header_type) = match &received_header { + CacheHeader::Track(header) => { + let header_type = DataStreamType::StreamHeaderTrack; + let subscribe_id = header.subscribe_id(); + + (subscribe_id, header_type) + } + CacheHeader::Subgroup(header) => { + let header_type = DataStreamType::StreamHeaderSubgroup; + let subscribe_id = header.subscribe_id(); + + (subscribe_id, header_type) + } + _ => { + let msg = "received header not matched".to_string(); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + }; + self.subscribe_id = Some(subscribe_id); + self.stream_header_type = Some(header_type); + + // Set upstream_subscription from received header + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + let upstream_subscription = match pubsub_relation_manager + .get_upstream_subscription_by_ids(self.stream.stable_id(), subscribe_id) + .await + { + Ok(upstream_subscription) => { + if upstream_subscription.is_none() { + let msg = "Upstream subscription not found".to_string(); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + + upstream_subscription.unwrap() + } + + Err(err) => { + let msg = format!("Fail to get upstream subscription: {:?}", err); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + }; + self.upstream_subscription = Some(upstream_subscription); + + // Set group id if the received object is subgroup + if let CacheHeader::Subgroup(header) = received_header { + *subgroup_group_id = Some(header.group_id()); + } + + Ok(()) + } + async fn read_object_stream_and_store_to_object_cache( &self, object_cache_storage: &mut ObjectCacheStorageWrapper, - ) -> Result { + ) -> Result, TerminationError> { let result = self .try_to_read_buf_as_object_stream_and_store_to_object_cache(object_cache_storage) .await; match result { - ObjectStreamProcessResult::Success => Ok(true), - ObjectStreamProcessResult::IncompleteMessage => Ok(false), + ObjectStreamProcessResult::Success(received_object) => Ok(Some(received_object)), + ObjectStreamProcessResult::IncompleteMessage => Ok(None), ObjectStreamProcessResult::Failure(code, reason) => { let msg = std::format!("object_stream_read failure: {:?}", reason); Err((code, msg)) @@ -278,4 +393,97 @@ impl UniStreamReceiver { ) .await } + + async fn judge_end_of_receiving( + &self, + received_object: &CacheObject, + subgroup_group_id: &Option, + ) -> Result { + let is_end_of_data_stream = self.judge_end_of_data_stream(received_object).await?; + if is_end_of_data_stream { + return Ok(true); + } + + let upstream_subscription = self.upstream_subscription.as_ref().unwrap(); + let filter_type = upstream_subscription.get_filter_type(); + if filter_type == FilterType::AbsoluteRange { + let is_end_of_absolute_range = self + .judge_end_of_absolute_range(received_object, subgroup_group_id) + .await?; + if is_end_of_absolute_range { + return Ok(true); + } + } + + Ok(false) + } + + async fn judge_end_of_data_stream( + &self, + received_object: &CacheObject, + ) -> Result { + let is_end = match received_object { + CacheObject::Track(object_stream_track) => { + matches!( + object_stream_track.object_status(), + Some(ObjectStatus::EndOfTrackAndGroup) + ) + } + CacheObject::Subgroup(object_stream_subgroup) => { + matches!( + object_stream_subgroup.object_status(), + Some(ObjectStatus::EndOfSubgroup) + | Some(ObjectStatus::EndOfGroup) + | Some(ObjectStatus::EndOfTrackAndGroup) + ) + } + _ => { + let msg = "received object not matched".to_string(); + let code = TerminationErrorCode::InternalError; + return Err((code, msg)); + } + }; + + Ok(is_end) + } + + async fn judge_end_of_absolute_range( + &self, + received_object: &CacheObject, + subgroup_group_id: &Option, + ) -> Result { + let upstream_subscription = self.upstream_subscription.as_ref().unwrap(); + let (end_group, end_object) = upstream_subscription.get_absolute_end(); + let end_group = end_group.unwrap(); + let end_object = end_object.unwrap(); + + let is_end = match received_object { + CacheObject::Track(object_stream_track) => { + let is_group_end = object_stream_track.group_id() == end_group; + let is_object_end = object_stream_track.object_id() == end_object; + let is_ending = is_group_end && is_object_end; + + let is_ended = object_stream_track.group_id() > end_group; + + is_ending || is_ended + } + CacheObject::Subgroup(object_stream_subgroup) => { + let subgroup_group_id = subgroup_group_id.unwrap(); + let is_group_end = subgroup_group_id == end_group; + let is_object_end = object_stream_subgroup.object_id() == end_object; + let is_ending = is_group_end && is_object_end; + + let is_ended = subgroup_group_id > end_group; + + is_ending || is_ended + } + _ => { + let msg = "received object not matched".to_string(); + let code = TerminationErrorCode::InternalError; + return Err((code, msg)); + } + }; + + Ok(is_end) + } } diff --git a/moqt-server/src/modules/server_processes/thread_starters.rs b/moqt-server/src/modules/server_processes/thread_starters.rs index bf0c4c5..9eb216e 100644 --- a/moqt-server/src/modules/server_processes/thread_starters.rs +++ b/moqt-server/src/modules/server_processes/thread_starters.rs @@ -106,26 +106,22 @@ async fn spawn_uni_recv_stream_thread( tokio::spawn( async move { let stream = UniRecvStream::new(stable_id, stream_id, recv_stream); - let mut uni_stream_receiver = UniStreamReceiver::init(stream, client) - .instrument(session_span) - .await; - - let (code, reason) = match uni_stream_receiver.start().await { - Ok(_) => { - let code = TerminationErrorCode::NoError; - let reason = "ObjectStreamForwarder: Finished".to_string(); - tracing::info!(reason); + let senders = client.lock().await.senders(); + let mut uni_stream_receiver = UniStreamReceiver::init(stream, client).await; - (code, reason) - } + match uni_stream_receiver.start().instrument(session_span).await { + Ok(_) => {} Err((code, reason)) => { tracing::error!(reason); - (code, reason) + let _ = senders + .close_session_tx() + .send((u8::from(code) as u64, reason.to_string())) + .await; } - }; + } - let _ = uni_stream_receiver.terminate(code, reason).await; + let _ = uni_stream_receiver.terminate().await; } .in_current_span(), );