diff --git a/README.md b/README.md index 38a75e1..b62139b 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ Supported version: draft-ietf-moq-transport-06 - [ ] TRACK_STATUS - [x] SUBSCRIBE_NAMESPACE_OK - [x] SUBSCRIBE_NAMESPACE_ERROR -- [ ] Data Streams - - [ ] Object Datagram Message +- [x] Data Streams + - [x] Object Datagram Message - [x] Track Stream - [x] Subgroup Stream - [ ] Features diff --git a/js/index.html b/js/index.html index 5c21b28..44b4214 100644 --- a/js/index.html +++ b/js/index.html @@ -17,6 +17,12 @@

Connection


+

Forwarding Preference

+ + + +
+ *You must select at first

Message


@@ -73,10 +79,9 @@

SUBSCRIBE

OBJECT

- - -
-

Header

+
+

Header

+

@@ -89,12 +94,14 @@

Header


-

Object

-
- Group ID:

0

+
+

Object

+
+
+ Group ID:

0


- - + +

@@ -104,6 +111,9 @@

Object


+
diff --git a/js/main.js b/js/main.js index 9ae7bbe..e46d2c1 100644 --- a/js/main.js +++ b/js/main.js @@ -6,7 +6,7 @@ init().then(async () => { let headerSend = false let objectId = 0n - let trackGroupId = 0n + let mutableGroupId = 0n const connectBtn = document.getElementById('connectBtn') connectBtn.addEventListener('click', async () => { @@ -57,7 +57,8 @@ init().then(async () => { if (isSuccess) { let expire = 0n - await client.sendSubscribeOkMessage(receivedSubscribeId, expire, authInfo) + const forwardingPreference = Array.from(form['forwarding-preference']).filter((elem) => elem.checked)[0].value + await client.sendSubscribeOkMessage(receivedSubscribeId, expire, authInfo, forwardingPreference) } else { // TODO: set accurate reasonPhrase let reasonPhrase = 'subscribe error' @@ -73,6 +74,11 @@ init().then(async () => { console.log({ subscribeNamespaceResponse }) }) + client.onObjectDatagram(async (objectDatagram) => { + console.log({ objectDatagram }) + describeReceivedObject(objectDatagram.object_payload) + }) + client.onStreamHeaderTrack(async (streamHeaderTrack) => { console.log({ streamHeaderTrack }) }) @@ -92,7 +98,7 @@ init().then(async () => { }) const objectIdElement = document.getElementById('objectId') - const trackGroupIdElement = document.getElementById('trackGroupId') + const mutableGroupIdElement = document.getElementById('mutableGroupId') const sendSetupBtn = document.getElementById('sendSetupBtn') sendSetupBtn.addEventListener('click', async () => { @@ -155,6 +161,28 @@ init().then(async () => { ) }) + const sendDatagramObjectBtn = document.getElementById('sendDatagramObjectBtn') + sendDatagramObjectBtn.addEventListener('click', async () => { + console.log('send datagram object btn clicked') + const subscribeId = form['object-subscribe-id'].value + const trackAlias = form['object-track-alias'].value + const publisherPriority = form['publisher-priority'].value + const objectPayloadString = form['object-payload'].value + + // encode the text to the object array + const objectPayloadArray = new TextEncoder().encode(objectPayloadString) + + await client.sendObjectDatagram( + BigInt(subscribeId), + BigInt(trackAlias), + mutableGroupId, + objectId++, + publisherPriority, + objectPayloadArray + ) + objectIdElement.textContent = objectId + }) + const sendTrackObjectBtn = document.getElementById('sendTrackObjectBtn') sendTrackObjectBtn.addEventListener('click', async () => { console.log('send track stream object btn clicked') @@ -172,7 +200,7 @@ init().then(async () => { headerSend = true } - await client.sendObjectStreamTrack(BigInt(subscribeId), trackGroupId, objectId++, objectPayloadArray) + await client.sendObjectStreamTrack(BigInt(subscribeId), mutableGroupId, objectId++, objectPayloadArray) objectIdElement.textContent = objectId }) @@ -205,50 +233,67 @@ init().then(async () => { objectIdElement.textContent = objectId }) - const ascendTrackGroupBtn = document.getElementById('ascendTrackGroupIdBtn') - ascendTrackGroupBtn.addEventListener('click', async () => { - trackGroupId++ + const ascendMutableGroupId = document.getElementById('ascendMutableGroupIdBtn') + ascendMutableGroupId.addEventListener('click', async () => { + mutableGroupId++ objectId = 0n - console.log('ascend trackGroupId', trackGroupId) + console.log('ascend mutableGroupId', mutableGroupId) - trackGroupIdElement.textContent = trackGroupId + mutableGroupIdElement.textContent = mutableGroupId objectIdElement.textContent = objectId }) - const descendTrackGroupBtn = document.getElementById('descendTrackGroupIdBtn') - descendTrackGroupBtn.addEventListener('click', async () => { - if (trackGroupId === 0n) { + const descendMutableGroupId = document.getElementById('descendMutableGroupIdBtn') + descendMutableGroupId.addEventListener('click', async () => { + if (mutableGroupId === 0n) { return } - trackGroupId-- + mutableGroupId-- objectId = 0n - console.log('descend trackGroupId', trackGroupId) - trackGroupIdElement.textContent = trackGroupId + console.log('descend mutableGroupId', mutableGroupId) + mutableGroupIdElement.textContent = mutableGroupId objectIdElement.textContent = objectId }) await client.start() }) - const dataStreamType = document.querySelectorAll('input[name="data-stream-type"]') + const forwardingPreference = document.querySelectorAll('input[name="forwarding-preference"]') const subgroupHeaderContents = document.getElementById('subgroupHeaderContents') - const trackObjectContents = document.getElementById('trackObjectContents') + const notSubgroupObjectContents = document.getElementById('notSubgroupObjectContents') + const sendDatagramObject = document.getElementById('sendDatagramObject') const sendTrackObject = document.getElementById('sendTrackObject') const sendSubgroupObject = document.getElementById('sendSubgroupObject') + const headerField = document.getElementById('headerField') + const objectField = document.getElementById('objectField') // change ui within track/subgroup - dataStreamType.forEach((elem) => { + forwardingPreference.forEach((elem) => { elem.addEventListener('change', async () => { - if (elem.value === 'track') { - trackObjectContents.style.display = 'block' + if (elem.value === 'datagram') { + notSubgroupObjectContents.style.display = 'block' + subgroupHeaderContents.style.display = 'none' + sendDatagramObject.style.display = 'block' + sendTrackObject.style.display = 'none' + sendSubgroupObject.style.display = 'none' + headerField.style.display = 'none' + objectField.style.display = 'none' + } else if (elem.value === 'track') { + notSubgroupObjectContents.style.display = 'block' subgroupHeaderContents.style.display = 'none' + sendDatagramObject.style.display = 'none' sendTrackObject.style.display = 'block' sendSubgroupObject.style.display = 'none' + headerField.style.display = 'block' + objectField.style.display = 'block' } else if (elem.value === 'subgroup') { - trackObjectContents.style.display = 'none' + notSubgroupObjectContents.style.display = 'none' subgroupHeaderContents.style.display = 'block' + sendDatagramObject.style.display = 'none' sendTrackObject.style.display = 'none' sendSubgroupObject.style.display = 'block' + headerField.style.display = 'block' + objectField.style.display = 'block' } }) }) diff --git a/moqt-client-sample/Cargo.toml b/moqt-client-sample/Cargo.toml index 4952189..7c472b8 100644 --- a/moqt-client-sample/Cargo.toml +++ b/moqt-client-sample/Cargo.toml @@ -27,6 +27,7 @@ anyhow = "1.0.75" version = "0.3.64" features = [ 'WebTransport', + "WebTransportDatagramDuplexStream", 'WebTransportBidirectionalStream', 'WebTransportSendStream', 'WebTransportReceiveStream', diff --git a/moqt-client-sample/src/lib.rs b/moqt-client-sample/src/lib.rs index 5cf205f..4323673 100644 --- a/moqt-client-sample/src/lib.rs +++ b/moqt-client-sample/src/lib.rs @@ -27,9 +27,9 @@ use moqt_core::{ }, messages::{ data_streams::{ - object_stream_subgroup::ObjectStreamSubgroup, object_stream_track::ObjectStreamTrack, - stream_header_subgroup::StreamHeaderSubgroup, stream_header_track::StreamHeaderTrack, - DataStreams, + object_datagram::ObjectDatagram, object_stream_subgroup::ObjectStreamSubgroup, + object_stream_track::ObjectStreamTrack, stream_header_subgroup::StreamHeaderSubgroup, + stream_header_track::StreamHeaderTrack, DataStreams, }, moqt_payload::MOQTPayload, }, @@ -79,6 +79,7 @@ pub struct MOQTClient { subscription_node: Rc>, transport: Rc>>, control_stream_writer: Rc>>, + object_datagram_writer: Rc>>, object_stream_writers: Rc>>, callbacks: Rc>, } @@ -94,6 +95,7 @@ impl MOQTClient { subscription_node: Rc::new(RefCell::new(SubscriptionNode::new())), transport: Rc::new(RefCell::new(None)), control_stream_writer: Rc::new(RefCell::new(None)), + object_datagram_writer: Rc::new(RefCell::new(None)), object_stream_writers: Rc::new(RefCell::new(HashMap::new())), callbacks: Rc::new(RefCell::new(MOQTCallbacks::new())), } @@ -137,6 +139,13 @@ impl MOQTClient { .set_subscribe_namespace_response_callback(callback); } + #[wasm_bindgen(js_name = onObjectDatagram)] + pub fn set_object_datagram_callback(&mut self, callback: js_sys::Function) { + self.callbacks + .borrow_mut() + .set_object_datagram_callback(callback); + } + #[wasm_bindgen(js_name = onStreamHeaderTrack)] pub fn set_stream_header_track_callback(&mut self, callback: js_sys::Function) { self.callbacks @@ -487,6 +496,7 @@ impl MOQTClient { subscribe_id: u64, expires: u64, auth_info: String, + fowarding_preference: String, ) -> Result { if let Some(writer) = &*self.control_stream_writer.borrow() { let auth_info = @@ -538,21 +548,38 @@ impl MOQTClient { .borrow_mut() .activate_as_publisher(subscribe_id); - let send_uni_stream = web_sys::WritableStream::from( - JsFuture::from( - self.transport + match &*fowarding_preference { + // stream + "datagram" => { + let datagram_writer = self + .transport .borrow() .as_ref() .unwrap() - .create_unidirectional_stream(), - ) - .await?, - ); - let send_uni_stream_writer = send_uni_stream.get_writer()?; - - self.object_stream_writers - .borrow_mut() - .insert(subscribe_id, send_uni_stream_writer); + .datagrams() + .writable() + .get_writer()?; + *self.object_datagram_writer.borrow_mut() = Some(datagram_writer); + } + "track" | "subgroup" => { + let send_uni_stream = web_sys::WritableStream::from( + JsFuture::from( + self.transport + .borrow() + .as_ref() + .unwrap() + .create_unidirectional_stream(), + ) + .await?, + ); + let send_uni_stream_writer = send_uni_stream.get_writer()?; + + self.object_stream_writers + .borrow_mut() + .insert(subscribe_id, send_uni_stream_writer); + } + _ => {} + } Ok(ok) } @@ -701,6 +728,54 @@ impl MOQTClient { } } + #[wasm_bindgen(js_name = sendObjectDatagram)] + pub async fn send_object_datagram( + &self, + subscribe_id: u64, + track_alias: u64, + group_id: u64, + object_id: u64, + publisher_priority: u8, + object_payload: Vec, + ) -> Result { + if let Some(writer) = &*self.object_datagram_writer.borrow() { + let object_datagram = ObjectDatagram::new( + subscribe_id, + track_alias, + group_id, + object_id, + publisher_priority, + None, + object_payload, + ) + .unwrap(); + let mut object_datagram_buf = BytesMut::new(); + let _ = object_datagram.packetize(&mut object_datagram_buf); + + let mut buf = Vec::new(); + // Message Type + buf.extend(write_variable_integer( + u8::from(DataStreamType::ObjectDatagram) as u64, + )); + buf.extend(object_datagram_buf); + + let buffer = js_sys::Uint8Array::new_with_length(buf.len() as u32); + buffer.copy_from(&buf); + match JsFuture::from(writer.write_with_chunk(&buffer)).await { + Ok(ok) => { + log(std::format!("sent: object id: {:#?}", object_id).as_str()); + Ok(ok) + } + Err(e) => { + log(std::format!("err: {:?}", e).as_str()); + Err(e) + } + } + } else { + return Err(JsValue::from_str("object_datagram_writer is None")); + } + } + #[wasm_bindgen(js_name = sendStreamHeaderTrackMessage)] pub async fn send_stream_header_track_message( &self, @@ -889,7 +964,15 @@ impl MOQTClient { .await; }); - // For receiving object messages + // // For receiving object messages as datagrams + let datagram_reader_readable = transport.datagrams().readable(); + let datagram_reader = web_sys::ReadableStreamDefaultReader::new(&datagram_reader_readable)?; + let callbacks = self.callbacks.clone(); + wasm_bindgen_futures::spawn_local(async move { + let _ = datagram_read_thread(callbacks, &datagram_reader).await; + }); + + // For receiving object messages as streams let incoming_uni_stream = transport.incoming_unidirectional_streams(); let incoming_uni_stream_reader = web_sys::ReadableStreamDefaultReader::new(&&incoming_uni_stream.into())?; @@ -1175,6 +1258,43 @@ async fn control_message_handler( Ok(()) } +#[cfg(web_sys_unstable_apis)] +async fn datagram_read_thread( + callbacks: Rc>, + reader: &ReadableStreamDefaultReader, +) -> Result<(), JsValue> { + log("datagram_read_thread"); + + let mut buf = BytesMut::new(); + + loop { + let ret = reader.read(); + let ret = JsFuture::from(ret).await?; + + let ret_value = js_sys::Reflect::get(&ret, &JsValue::from_str("value"))?; + let ret_done = js_sys::Reflect::get(&ret, &JsValue::from_str("done"))?; + let ret_done = js_sys::Boolean::from(ret_done).value_of(); + + if ret_done { + break; + } + + let ret_value = js_sys::Uint8Array::from(ret_value).to_vec(); + + for i in ret_value { + buf.put_u8(i); + } + + while buf.len() > 0 { + if let Err(e) = datagram_handler(callbacks.clone(), &mut buf).await { + log(std::format!("error: {:#?}", e).as_str()); + break; + } + } + } + Ok(()) +} + #[cfg(web_sys_unstable_apis)] async fn uni_directional_stream_read_thread( callbacks: Rc>, @@ -1236,7 +1356,7 @@ async fn uni_directional_stream_read_thread( object_stream_subgroup_handler(callbacks.clone(), &mut buf).await { log(std::format!("error: {:#?}", e).as_str()); - return Err(js_sys::Error::new(&e.to_string()).into()); + break; } } } @@ -1313,6 +1433,53 @@ async fn object_header_handler( Ok(data_stream_type) } +#[cfg(web_sys_unstable_apis)] +async fn datagram_handler(callbacks: Rc>, buf: &mut BytesMut) -> Result<()> { + let mut read_cur = Cursor::new(&buf[..]); + let header_type_value = read_variable_integer(&mut read_cur); + + match header_type_value { + Ok(v) => { + let data_stream_type = DataStreamType::try_from(v as u8)?; + + log(std::format!("data_stream_type_value: {:#x?}", data_stream_type).as_str()); + + if data_stream_type == DataStreamType::ObjectDatagram { + let object_datagram = match ObjectDatagram::depacketize(&mut read_cur) { + Ok(v) => { + log(std::format!("object_id: {:#?}", v.object_id()).as_str()); + buf.advance(read_cur.position() as usize); + v + } + Err(e) => { + read_cur.set_position(0); + log(std::format!("retry because: {:#?}", e).as_str()); + return Err(e); + } + }; + + if let Some(callback) = callbacks.borrow().object_datagram_callback() { + callback + .call1(&JsValue::null(), &JsValue::from("called2")) + .unwrap(); + let v = serde_wasm_bindgen::to_value(&object_datagram).unwrap(); + callback.call1(&JsValue::null(), &(v)).unwrap(); + } + } else { + let msg = "format error".to_string(); + log(std::format!("{}", msg).as_str()); + return Err(anyhow::anyhow!(msg)); + } + } + Err(e) => { + log("data_stream_type_value is None"); + return Err(e); + } + } + + Ok(()) +} + #[cfg(web_sys_unstable_apis)] async fn object_stream_track_handler( callbacks: Rc>, @@ -1561,6 +1728,7 @@ struct MOQTCallbacks { subscribe_callback: Option, subscribe_response_callback: Option, subscribe_namespace_response_callback: Option, + object_datagram_callback: Option, stream_header_track_callback: Option, object_stream_track_callback: Option, stream_header_subgroup_callback: Option, @@ -1577,6 +1745,7 @@ impl MOQTCallbacks { subscribe_callback: None, subscribe_response_callback: None, subscribe_namespace_response_callback: None, + object_datagram_callback: None, stream_header_track_callback: None, object_stream_track_callback: None, stream_header_subgroup_callback: None, @@ -1632,6 +1801,14 @@ impl MOQTCallbacks { self.subscribe_namespace_response_callback = Some(callback); } + pub fn object_datagram_callback(&self) -> Option { + self.object_datagram_callback.clone() + } + + pub fn set_object_datagram_callback(&mut self, callback: js_sys::Function) { + self.object_datagram_callback = Some(callback); + } + pub fn stream_header_track_callback(&self) -> Option { self.stream_header_track_callback.clone() } diff --git a/moqt-core/src/modules/messages/data_streams/object_datagram.rs b/moqt-core/src/modules/messages/data_streams/object_datagram.rs index 6322e02..b6b05eb 100644 --- a/moqt-core/src/modules/messages/data_streams/object_datagram.rs +++ b/moqt-core/src/modules/messages/data_streams/object_datagram.rs @@ -56,6 +56,10 @@ impl ObjectDatagram { }) } + pub fn subscribe_id(&self) -> u64 { + self.subscribe_id + } + pub fn track_alias(&self) -> u64 { self.track_alias } @@ -67,6 +71,18 @@ impl ObjectDatagram { pub fn object_id(&self) -> u64 { self.object_id } + + pub fn publisher_priority(&self) -> u8 { + self.publisher_priority + } + + pub fn object_status(&self) -> Option { + self.object_status + } + + pub fn object_payload(&self) -> Vec { + self.object_payload.clone() + } } impl DataStreams for ObjectDatagram { diff --git a/moqt-server/src/modules/message_handlers.rs b/moqt-server/src/modules/message_handlers.rs index 1ac29f5..34ed165 100644 --- a/moqt-server/src/modules/message_handlers.rs +++ b/moqt-server/src/modules/message_handlers.rs @@ -1,3 +1,4 @@ pub(crate) mod control_message; +pub(crate) mod object_datagram; pub(crate) mod object_stream; pub(crate) mod stream_header; diff --git a/moqt-server/src/modules/message_handlers/object_datagram.rs b/moqt-server/src/modules/message_handlers/object_datagram.rs new file mode 100644 index 0000000..fcd433a --- /dev/null +++ b/moqt-server/src/modules/message_handlers/object_datagram.rs @@ -0,0 +1,172 @@ +use crate::constants::TerminationErrorCode; +use crate::modules::{ + moqt_client::{MOQTClient, MOQTClientStatus}, + object_cache_storage::{CacheHeader, CacheObject, ObjectCacheStorageWrapper}, +}; +use anyhow::{bail, Result}; +use bytes::{Buf, BytesMut}; +use moqt_core::{ + data_stream_type::DataStreamType, + messages::data_streams::{object_datagram::ObjectDatagram, DataStreams}, + models::tracks::ForwardingPreference, + variable_integer::read_variable_integer, + PubSubRelationManagerRepository, +}; +use std::{io::Cursor, sync::Arc}; +use tokio::sync::Mutex; + +#[derive(Debug, PartialEq)] +pub enum ObjectDatagramProcessResult { + Success((CacheObject, bool)), + IncompleteMessage, + Failure(TerminationErrorCode, String), +} + +fn read_header_type(read_cur: &mut std::io::Cursor<&[u8]>) -> Result { + let type_value = match read_variable_integer(read_cur) { + Ok(v) => v as u8, + Err(err) => { + bail!(err.to_string()); + } + }; + + let header_type: DataStreamType = match DataStreamType::try_from(type_value) { + Ok(v) => { + if v == DataStreamType::StreamHeaderTrack || v == DataStreamType::StreamHeaderSubgroup { + bail!("{:?} is not header type", v); + } + v + } + Err(err) => { + bail!(err.to_string()); + } + }; + Ok(header_type) +} + +pub async fn object_datagram_handler( + read_buf: &mut BytesMut, + client: Arc>, + pubsub_relation_manager_repository: &mut dyn PubSubRelationManagerRepository, + object_cache_storage: &mut ObjectCacheStorageWrapper, +) -> ObjectDatagramProcessResult { + let payload_length = read_buf.len(); + tracing::trace!("object_datagram_handler! {}", payload_length); + + // Check if the data is exist + if payload_length == 0 { + return ObjectDatagramProcessResult::IncompleteMessage; + } + + // TODO: Set the accurate duration + let duration = 100000; + + let mut read_cur = Cursor::new(&read_buf[..]); + + // Read the header type + let header_type = match read_header_type(&mut read_cur) { + Ok(v) => v, + Err(err) => { + read_buf.advance(read_cur.position() as usize); + + tracing::error!("header_type is wrong: {:?}", err); + return ObjectDatagramProcessResult::Failure( + TerminationErrorCode::ProtocolViolation, + err.to_string(), + ); + } + }; + + let client_status: MOQTClientStatus; + let upstream_session_id: usize; + { + let client = client.lock().await; + client_status = client.status(); + upstream_session_id = client.id(); + } + // check subscription and judge if it is invalid timing + if client_status != MOQTClientStatus::SetUp { + let message = String::from("Invalid timing"); + tracing::error!(message); + return ObjectDatagramProcessResult::Failure( + TerminationErrorCode::ProtocolViolation, + message, + ); + } + + tracing::debug!("object_stream: read_buf: {:?}", read_buf); + + match header_type { + DataStreamType::ObjectDatagram => { + let result = ObjectDatagram::depacketize(&mut read_cur); + match result { + Ok(object) => { + read_buf.advance(read_cur.position() as usize); + + let upstream_subscribe_id = object.subscribe_id(); + + let is_first_time = match object_cache_storage + .get_header(upstream_session_id, upstream_subscribe_id) + .await + { + Ok(CacheHeader::Datagram) => { + // It's not first time to receive datagram + false + } + Err(_) => { + // It's first time to receive datagram + let _ = pubsub_relation_manager_repository + .set_upstream_forwarding_preference( + upstream_session_id, + upstream_subscribe_id, + ForwardingPreference::Datagram, + ) + .await; + + let _ = object_cache_storage + .set_subscription( + upstream_session_id, + upstream_subscribe_id, + CacheHeader::Datagram, + ) + .await; + + true + } + _ => { + let msg = "failed to get cache header, error: unexpected cache header is already set" + .to_string(); + tracing::error!(msg); + return ObjectDatagramProcessResult::Failure( + TerminationErrorCode::InternalError, + msg, + ); + } + }; + + let received_object = CacheObject::Datagram(object); + object_cache_storage + .set_object( + upstream_session_id, + upstream_subscribe_id, + received_object.clone(), + duration, + ) + .await + .unwrap(); + + ObjectDatagramProcessResult::Success((received_object, is_first_time)) + } + Err(err) => { + tracing::warn!("{:#?}", err); + read_cur.set_position(0); + ObjectDatagramProcessResult::IncompleteMessage + } + } + } + _ => ObjectDatagramProcessResult::Failure( + TerminationErrorCode::ProtocolViolation, + format!("Invalid message type: {:?}", header_type), + ), + } +} diff --git a/moqt-server/src/modules/message_handlers/object_stream.rs b/moqt-server/src/modules/message_handlers/object_stream.rs index d2b19eb..aec4044 100644 --- a/moqt-server/src/modules/message_handlers/object_stream.rs +++ b/moqt-server/src/modules/message_handlers/object_stream.rs @@ -72,6 +72,7 @@ pub async fn object_stream_handler( Err(err) => { tracing::warn!("{:#?}", err); read_cur.set_position(0); + ObjectStreamProcessResult::IncompleteMessage } } @@ -93,6 +94,7 @@ pub async fn object_stream_handler( Err(err) => { tracing::warn!("{:#?}", err); read_cur.set_position(0); + ObjectStreamProcessResult::IncompleteMessage } } diff --git a/moqt-server/src/modules/moqt_client.rs b/moqt-server/src/modules/moqt_client.rs index 4db229b..510540f 100644 --- a/moqt-server/src/modules/moqt_client.rs +++ b/moqt-server/src/modules/moqt_client.rs @@ -14,7 +14,6 @@ pub struct MOQTClient { id: usize, status: MOQTClientStatus, role: Option, - senders: Arc, } diff --git a/moqt-server/src/modules/server_processes/session_handler.rs b/moqt-server/src/modules/server_processes/session_handler.rs index c93e734..ce9e254 100644 --- a/moqt-server/src/modules/server_processes/session_handler.rs +++ b/moqt-server/src/modules/server_processes/session_handler.rs @@ -21,7 +21,7 @@ use tracing::{self}; use wtransport::{endpoint::IncomingSession, Connection}; pub(crate) struct SessionHandler { - session: Connection, + session: Arc, client: Arc>, close_session_rx: mpsc::Receiver<(u64, String)>, open_downstream_stream_or_datagram_rx: mpsc::Receiver<(u64, DataStreamType)>, @@ -86,6 +86,7 @@ impl SessionHandler { .insert(stable_id, open_downstream_stream_or_datagram_tx); let client = Arc::new(Mutex::new(MOQTClient::new(stable_id, senders))); + let session = Arc::new(session); let session_handler = SessionHandler { session, @@ -103,7 +104,7 @@ impl SessionHandler { loop { match select_spawn_thread( &self.client, - &self.session, + self.session.clone(), &mut self.open_downstream_stream_or_datagram_rx, &mut self.close_session_rx, &mut is_control_stream_opened, diff --git a/moqt-server/src/modules/server_processes/stream_and_datagram.rs b/moqt-server/src/modules/server_processes/stream_and_datagram.rs index 03fce4d..9e07c72 100644 --- a/moqt-server/src/modules/server_processes/stream_and_datagram.rs +++ b/moqt-server/src/modules/server_processes/stream_and_datagram.rs @@ -1,2 +1,3 @@ pub(crate) mod bi_directional_stream; +pub(crate) mod datagram; pub(crate) mod uni_directional_stream; diff --git a/moqt-server/src/modules/server_processes/stream_and_datagram/datagram.rs b/moqt-server/src/modules/server_processes/stream_and_datagram/datagram.rs new file mode 100644 index 0000000..1360d95 --- /dev/null +++ b/moqt-server/src/modules/server_processes/stream_and_datagram/datagram.rs @@ -0,0 +1,2 @@ +pub(crate) mod forwarder; +pub(crate) mod receiver; diff --git a/moqt-server/src/modules/server_processes/stream_and_datagram/datagram/forwarder.rs b/moqt-server/src/modules/server_processes/stream_and_datagram/datagram/forwarder.rs new file mode 100644 index 0000000..7552cbb --- /dev/null +++ b/moqt-server/src/modules/server_processes/stream_and_datagram/datagram/forwarder.rs @@ -0,0 +1,340 @@ +use crate::modules::{ + buffer_manager::BufferCommand, + moqt_client::MOQTClient, + object_cache_storage::{CacheObject, ObjectCacheStorageWrapper}, + pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, + server_processes::senders::Senders, +}; +use anyhow::{bail, Result}; +use bytes::BytesMut; +use moqt_core::{ + data_stream_type::DataStreamType, + messages::{ + control_messages::subscribe::FilterType, + data_streams::{object_datagram::ObjectDatagram, object_status::ObjectStatus, DataStreams}, + }, + models::{subscriptions::Subscription, tracks::ForwardingPreference}, + pubsub_relation_manager_repository::PubSubRelationManagerRepository, + variable_integer::write_variable_integer, +}; +use std::{sync::Arc, thread, time::Duration}; +use tokio::sync::Mutex; +use tracing::{self}; +use wtransport::Connection; + +struct ObjectCacheKey { + session_id: usize, + subscribe_id: u64, +} + +impl ObjectCacheKey { + fn new(session_id: usize, subscribe_id: u64) -> Self { + ObjectCacheKey { + session_id, + subscribe_id, + } + } + + fn session_id(&self) -> usize { + self.session_id + } + + fn subscribe_id(&self) -> u64 { + self.subscribe_id + } +} + +pub(crate) struct DatagramForwarder { + session: Arc, + senders: Arc, + downstream_subscribe_id: u64, + downstream_subscription: Subscription, + object_cache_key: ObjectCacheKey, + sleep_time: Duration, +} + +impl DatagramForwarder { + pub(crate) async fn init( + session: Arc, + downstream_subscribe_id: u64, + client: Arc>, + ) -> Result { + let senders = client.lock().await.senders(); + let sleep_time = Duration::from_millis(10); + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(senders.pubsub_relation_tx().clone()); + + let downstream_session_id = session.stable_id(); + + let downstream_subscription = pubsub_relation_manager + .get_downstream_subscription_by_ids(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); + + // Get the information of the original publisher who has the track being requested + let (upstream_session_id, upstream_subscribe_id) = pubsub_relation_manager + .get_related_publisher(downstream_session_id, downstream_subscribe_id) + .await?; + + let object_cache_key = ObjectCacheKey::new(upstream_session_id, upstream_subscribe_id); + + let datagram_forwarder = DatagramForwarder { + session, + senders, + downstream_subscribe_id, + downstream_subscription, + object_cache_key, + sleep_time, + }; + + Ok(datagram_forwarder) + } + + pub(crate) async fn start(&mut self) -> Result<()> { + let mut object_cache_storage = + ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); + + self.check_and_set_forwarding_preference().await?; + + self.forward_loop(&mut object_cache_storage).await?; + + Ok(()) + } + + pub(crate) async fn terminate(&self) -> Result<()> { + let downstream_session_id = self.session.stable_id(); + let downstream_stream_id = 0; // stream_id of datagram does not exist (TODO: delete buffer manager) + self.senders + .buffer_tx() + .send(BufferCommand::ReleaseStream { + session_id: downstream_session_id, + stream_id: downstream_stream_id, + }) + .await?; + + tracing::info!("Terminated DatagramForwarder"); + + Ok(()) + } + + async fn check_and_set_forwarding_preference(&self) -> Result<()> { + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + + let downstream_session_id = self.session.stable_id(); + let downstream_subscribe_id = self.downstream_subscribe_id; + let upstream_session_id = self.object_cache_key.session_id(); + let upstream_subscribe_id = self.object_cache_key.subscribe_id(); + + match pubsub_relation_manager + .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) + .await? + { + Some(ForwardingPreference::Datagram) => { + pubsub_relation_manager + .set_downstream_forwarding_preference( + downstream_session_id, + downstream_subscribe_id, + ForwardingPreference::Datagram, + ) + .await?; + } + _ => { + let msg = "Invalid forwarding preference"; + bail!(msg); + } + }; + + Ok(()) + } + + async fn forward_loop( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result<()> { + let mut object_cache_id = None; + let mut is_end = false; + + loop { + if is_end { + break; + } + + (object_cache_id, is_end) = self + .get_and_forward_object(object_cache_storage, object_cache_id) + .await?; + } + + Ok(()) + } + + async fn get_and_forward_object( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + cache_id: Option, + ) -> Result<(Option, bool)> { + // Do loop until get an object from the cache storage + loop { + let cache = match cache_id { + None => self.try_to_get_first_object(object_cache_storage).await?, + Some(cache_id) => { + self.try_to_get_subsequent_object(object_cache_storage, cache_id) + .await? + } + }; + + match cache { + None => { + // If there is no object in the cache storage, sleep for a while and try again + thread::sleep(self.sleep_time); + continue; + } + Some((cache_id, cache_object)) => { + self.packetize_and_forward_object(&cache_object).await?; + + let is_end = self.judge_end_of_forwarding(&cache_object).await?; + + return Ok((Some(cache_id), is_end)); + } + } + } + } + + async fn try_to_get_first_object( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result> { + let filter_type = self.downstream_subscription.get_filter_type(); + let upstream_session_id = self.object_cache_key.session_id(); + let upstream_subscribe_id = self.object_cache_key.subscribe_id(); + + match filter_type { + FilterType::LatestGroup => { + object_cache_storage + .get_latest_group(upstream_session_id, upstream_subscribe_id) + .await + } + FilterType::LatestObject => { + object_cache_storage + .get_latest_object(upstream_session_id, upstream_subscribe_id) + .await + } + FilterType::AbsoluteStart | FilterType::AbsoluteRange => { + let (start_group, start_object) = self.downstream_subscription.get_absolute_start(); + + object_cache_storage + .get_absolute_object( + upstream_session_id, + upstream_subscribe_id, + start_group.unwrap(), + start_object.unwrap(), + ) + .await + } + } + } + + async fn try_to_get_subsequent_object( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + object_cache_id: usize, + ) -> Result> { + let upstream_session_id = self.object_cache_key.session_id(); + let upstream_subscribe_id = self.object_cache_key.subscribe_id(); + + object_cache_storage + .get_next_object(upstream_session_id, upstream_subscribe_id, object_cache_id) + .await + } + + async fn packetize_and_forward_object(&mut self, cache_object: &CacheObject) -> Result<()> { + let message_buf = match cache_object { + CacheObject::Datagram(object_datagram) => { + self.packetize_forwarding_object_datagram(object_datagram) + .await + } + _ => { + let msg = "cache object not matched"; + bail!(msg) + } + }; + + if let Err(e) = self.session.send_datagram(&message_buf) { + tracing::warn!("Failed to send datagram: {:?}", e); + bail!(e); + } + + Ok(()) + } + + async fn judge_end_of_forwarding(&self, cache_object: &CacheObject) -> Result { + let is_end_of_data_stream = self.judge_end_of_data_stream(cache_object).await?; + if is_end_of_data_stream { + return Ok(true); + } + + let filter_type = self.downstream_subscription.get_filter_type(); + if filter_type == FilterType::AbsoluteRange { + let is_end_of_absolute_range = self.judge_end_of_absolute_range(cache_object).await?; + if is_end_of_absolute_range { + return Ok(true); + } + } + + Ok(false) + } + + async fn judge_end_of_data_stream(&self, cache_object: &CacheObject) -> Result { + let is_end = match cache_object { + CacheObject::Datagram(object_datagram) => { + matches!( + object_datagram.object_status(), + Some(ObjectStatus::EndOfTrackAndGroup) + ) + } + _ => { + let msg = "cache object not matched"; + bail!(msg) + } + }; + + Ok(is_end) + } + + async fn judge_end_of_absolute_range(&self, cache_object: &CacheObject) -> Result { + let (end_group, end_object) = self.downstream_subscription.get_absolute_end(); + let end_group = end_group.unwrap(); + let end_object = end_object.unwrap(); + + let is_end = match cache_object { + CacheObject::Datagram(object_datagram) => { + let is_group_end = object_datagram.group_id() == end_group; + let is_object_end = object_datagram.object_id() == end_object; + let is_ending = is_group_end && is_object_end; + + let is_ended = object_datagram.group_id() > end_group; + + is_ending || is_ended + } + _ => { + let msg = "cache object not matched"; + bail!(msg) + } + }; + + Ok(is_end) + } + + async fn packetize_forwarding_object_datagram(&self, object: &ObjectDatagram) -> BytesMut { + let mut buf = BytesMut::new(); + object.packetize(&mut buf); + + let mut message_buf = BytesMut::with_capacity(buf.len()); + message_buf.extend(write_variable_integer( + u8::from(DataStreamType::ObjectDatagram) as u64, + )); + message_buf.extend(buf); + + message_buf + } +} diff --git a/moqt-server/src/modules/server_processes/stream_and_datagram/datagram/receiver.rs b/moqt-server/src/modules/server_processes/stream_and_datagram/datagram/receiver.rs new file mode 100644 index 0000000..a91aeba --- /dev/null +++ b/moqt-server/src/modules/server_processes/stream_and_datagram/datagram/receiver.rs @@ -0,0 +1,161 @@ +use crate::{ + modules::{ + buffer_manager::request_buffer, + message_handlers::object_datagram::{object_datagram_handler, ObjectDatagramProcessResult}, + moqt_client::MOQTClient, + object_cache_storage::{CacheObject, ObjectCacheStorageWrapper}, + pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, + server_processes::senders::Senders, + }, + TerminationError, +}; +use anyhow::Result; +use bytes::BytesMut; +use moqt_core::{ + constants::TerminationErrorCode, data_stream_type::DataStreamType, + pubsub_relation_manager_repository::PubSubRelationManagerRepository, +}; +use std::sync::Arc; +use tokio::sync::Mutex; +use tracing::{self}; +use wtransport::datagram::Datagram; + +pub(crate) struct DatagramReceiver { + buf: Arc>, + senders: Arc, + client: Arc>, +} + +impl DatagramReceiver { + pub(crate) async fn init(client: Arc>) -> Self { + let senders = client.lock().await.senders(); + let stable_id = client.lock().await.id(); + let stream_id = 0; // stream_id of datagram does not exist (TODO: delete buffer manager) + let buf = request_buffer(senders.buffer_tx().clone(), stable_id, stream_id).await; + + DatagramReceiver { + buf, + senders, + client, + } + } + + pub(crate) async fn receive(&mut self, datagram: Datagram) -> Result<(), TerminationError> { + let mut object_cache_storage = + ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); + + self.add_datagram_to_buf(datagram).await; + + self.read_buf_and_store_to_object_cache(&mut object_cache_storage) + .await?; + + Ok(()) + } + + async fn add_datagram_to_buf(&mut self, datagram: Datagram) { + let datagram_payload = datagram.payload(); + + let read_buf = BytesMut::from(&datagram_payload[..]); + self.buf.lock().await.extend_from_slice(&read_buf); + } + + async fn read_buf_and_store_to_object_cache( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result<(), TerminationError> { + loop { + let is_succeeded = self + .read_object_datagram_and_store_to_object_cache(object_cache_storage) + .await?; + + if !is_succeeded { + break; + } + } + + Ok(()) + } + + async fn read_object_datagram_and_store_to_object_cache( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result { + let result = self + .try_to_read_buf_and_store_to_object_cache(object_cache_storage) + .await; + + match result { + ObjectDatagramProcessResult::Success((received_object, is_first_time)) => { + if is_first_time { + if let CacheObject::Datagram(received_object) = received_object { + match self + .open_downstream_datagram(received_object.subscribe_id()) + .await + { + Ok(_) => {} + Err(err) => { + let msg = format!("Fail to open downstream datagram: {:?}", err); + let code = TerminationErrorCode::InternalError; + + return Err((code, msg)); + } + }; + } + } + Ok(true) + } + + ObjectDatagramProcessResult::IncompleteMessage => Ok(false), + ObjectDatagramProcessResult::Failure(code, reason) => { + let msg = std::format!("object_stream_read failure: {:?}", reason); + tracing::error!(msg); + Err((code, reason)) + } + } + } + + async fn open_downstream_datagram(&self, upstream_subscribe_id: u64) -> Result<()> { + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + let upstream_session_id = self.client.lock().await.id(); + let open_subscription_txes = self.senders.open_downstream_stream_or_datagram_txes(); + let header_type = DataStreamType::ObjectDatagram; + + let subscribers = pubsub_relation_manager + .get_related_subscribers(upstream_session_id, upstream_subscribe_id) + .await + .unwrap(); + + for (downstream_session_id, downstream_subscribe_id) in subscribers { + let open_subscription_tx = open_subscription_txes + .lock() + .await + .get(&downstream_session_id) + .unwrap() + .clone(); + + let _ = open_subscription_tx + .send((downstream_subscribe_id, header_type)) + .await; + } + Ok(()) + } + + async fn try_to_read_buf_and_store_to_object_cache( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> ObjectDatagramProcessResult { + let mut process_buf = self.buf.lock().await; + let client = self.client.clone(); + let mut pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + + object_datagram_handler( + &mut process_buf, + client, + &mut pubsub_relation_manager, + object_cache_storage, + ) + .await + } +} diff --git a/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/forwarder.rs b/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/forwarder.rs index 8a1c0a8..0447856 100644 --- a/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/forwarder.rs +++ b/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/forwarder.rs @@ -3,20 +3,21 @@ use crate::modules::{ moqt_client::MOQTClient, object_cache_storage::{CacheHeader, CacheObject, ObjectCacheStorageWrapper}, pubsub_relation_manager::wrapper::PubSubRelationManagerWrapper, + server_processes::senders::Senders, }; use anyhow::{bail, Result}; use bytes::BytesMut; use moqt_core::{ - constants::TerminationErrorCode, data_stream_type::DataStreamType, messages::{ control_messages::subscribe::FilterType, data_streams::{ - stream_header_subgroup::StreamHeaderSubgroup, stream_header_track::StreamHeaderTrack, - DataStreams, + object_status::ObjectStatus, object_stream_subgroup::ObjectStreamSubgroup, + object_stream_track::ObjectStreamTrack, stream_header_subgroup::StreamHeaderSubgroup, + stream_header_track::StreamHeaderTrack, DataStreams, }, }, - models::tracks::ForwardingPreference, + models::{subscriptions::Subscription, tracks::ForwardingPreference}, pubsub_relation_manager_repository::PubSubRelationManagerRepository, variable_integer::write_variable_integer, }; @@ -26,174 +27,268 @@ use tracing::{self}; use super::streams::UniSendStream; -pub(crate) async fn forward_object_stream( - stream: &mut UniSendStream, - client: Arc>, +struct ObjectCacheKey { + session_id: usize, + subscribe_id: u64, +} + +impl ObjectCacheKey { + fn new(session_id: usize, subscribe_id: u64) -> Self { + ObjectCacheKey { + session_id, + subscribe_id, + } + } + + fn session_id(&self) -> usize { + self.session_id + } + + fn subscribe_id(&self) -> u64 { + self.subscribe_id + } +} + +pub(crate) struct ObjectStreamForwarder { + stream: UniSendStream, + senders: Arc, + downstream_subscription: Subscription, data_stream_type: DataStreamType, -) -> Result<()> { - let senders = client.lock().await.senders(); - let sleep_time = Duration::from_millis(10); + object_cache_key: ObjectCacheKey, + sleep_time: Duration, +} - let mut subgroup_header: Option = None; // For end group of AbsoluteRange +impl ObjectStreamForwarder { + pub(crate) async fn init( + stream: UniSendStream, + client: Arc>, + data_stream_type: DataStreamType, + ) -> Result { + let senders = client.lock().await.senders(); + let sleep_time = Duration::from_millis(10); + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(senders.pubsub_relation_tx().clone()); + + let downstream_session_id = stream.stable_id(); + let downstream_subscribe_id = stream.subscribe_id(); + + let downstream_subscription = pubsub_relation_manager + .get_downstream_subscription_by_ids(downstream_session_id, downstream_subscribe_id) + .await? + .unwrap(); - let downstream_session_id = stream.stable_id(); - let downstream_stream_id = stream.stream_id(); - let downstream_subscribe_id = stream.subscribe_id(); + // Get the information of the original publisher who has the track being requested + let (upstream_session_id, upstream_subscribe_id) = pubsub_relation_manager + .get_related_publisher(downstream_session_id, downstream_subscribe_id) + .await?; - let pubsub_relation_manager = - PubSubRelationManagerWrapper::new(senders.pubsub_relation_tx().clone()); + let object_cache_key = ObjectCacheKey::new(upstream_session_id, upstream_subscribe_id); - let mut object_cache_storage = - ObjectCacheStorageWrapper::new(senders.object_cache_tx().clone()); + let object_stream_forwarder = ObjectStreamForwarder { + stream, + senders, + downstream_subscription, + data_stream_type, + object_cache_key, + sleep_time, + }; - // Get the information of the original publisher who has the track being requested - let (upstream_session_id, upstream_subscribe_id) = pubsub_relation_manager - .get_related_publisher(downstream_session_id, downstream_subscribe_id) - .await?; - let downstream_subscription = pubsub_relation_manager - .get_downstream_subscription_by_ids(downstream_session_id, downstream_subscribe_id) - .await? - .unwrap(); - let downstream_track_alias = downstream_subscription.get_track_alias(); - let filter_type = downstream_subscription.get_filter_type(); - let (start_group, start_object) = downstream_subscription.get_absolute_start(); - let (end_group, end_object) = downstream_subscription.get_absolute_end(); - - // Validate the forwarding preference as Track - match pubsub_relation_manager - .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) - .await? - { - Some(ForwardingPreference::Track) => { - if data_stream_type != DataStreamType::StreamHeaderTrack { - let msg = std::format!( - "uni send stream's data stream type is wrong (expected Track, but got {:?})", - data_stream_type - ); - senders - .close_session_tx() - .send(( - u8::from(TerminationErrorCode::InternalError) as u64, - msg.clone(), - )) - .await?; - bail!(msg) + Ok(object_stream_forwarder) + } + + pub(crate) async fn start(&mut self) -> Result<()> { + let mut object_cache_storage = + ObjectCacheStorageWrapper::new(self.senders.object_cache_tx().clone()); + + let mut subgroup_group_id: Option = None; + + self.check_and_set_forwarding_preference().await?; + + self.get_and_forward_header(&mut object_cache_storage, &mut subgroup_group_id) + .await?; + + self.forward_loop(&mut object_cache_storage, &subgroup_group_id) + .await?; + + Ok(()) + } + + pub(crate) async fn terminate(&self) -> Result<()> { + let downstream_session_id = self.stream.stable_id(); + let downstream_stream_id = self.stream.stream_id(); + self.senders + .buffer_tx() + .send(BufferCommand::ReleaseStream { + session_id: downstream_session_id, + stream_id: downstream_stream_id, + }) + .await?; + + tracing::info!("Terminated ObjectStreamForwarder"); + + Ok(()) + } + + async fn check_and_set_forwarding_preference(&self) -> Result<()> { + let pubsub_relation_manager = + PubSubRelationManagerWrapper::new(self.senders.pubsub_relation_tx().clone()); + + let downstream_session_id = self.stream.stable_id(); + let downstream_subscribe_id = self.stream.subscribe_id(); + let upstream_session_id = self.object_cache_key.session_id(); + let upstream_subscribe_id = self.object_cache_key.subscribe_id(); + + match pubsub_relation_manager + .get_upstream_forwarding_preference(upstream_session_id, upstream_subscribe_id) + .await? + { + Some(ForwardingPreference::Track) => { + if self.data_stream_type == DataStreamType::StreamHeaderTrack { + pubsub_relation_manager + .set_downstream_forwarding_preference( + downstream_session_id, + downstream_subscribe_id, + ForwardingPreference::Track, + ) + .await?; + } else { + let msg = std::format!( + "uni send stream's data stream type is wrong (expected Track, but got {:?})", + self.data_stream_type + ); + bail!(msg); + } } - pubsub_relation_manager - .set_downstream_forwarding_preference( - downstream_session_id, - downstream_subscribe_id, - ForwardingPreference::Track, + Some(ForwardingPreference::Subgroup) => { + if self.data_stream_type == DataStreamType::StreamHeaderSubgroup { + pubsub_relation_manager + .set_downstream_forwarding_preference( + downstream_session_id, + downstream_subscribe_id, + ForwardingPreference::Subgroup, + ) + .await?; + } else { + let msg = std::format!( + "uni send stream's data stream type is wrong (expected Subgroup, but got {:?})", + self.data_stream_type + ); + bail!(msg); + } + } + + _ => { + let msg = "Invalid forwarding preference"; + bail!(msg); + } + }; + + Ok(()) + } + + async fn get_and_forward_header( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + subgroup_group_id: &mut Option, + ) -> Result<()> { + let upstream_session_id = self.object_cache_key.session_id(); + let upstream_subscribe_id = self.object_cache_key.subscribe_id(); + + // Get the header from the cache storage and send it to the client + let message_buf = match object_cache_storage + .get_header(upstream_session_id, upstream_subscribe_id) + .await? + { + CacheHeader::Track(stream_header_track) => { + self.packetize_forwarding_stream_header_track(&stream_header_track) + .await + } + CacheHeader::Subgroup(stream_header_subgroup) => { + self.packetize_forwarding_stream_header_subgroup( + &stream_header_subgroup, + subgroup_group_id, ) - .await?; - } - Some(ForwardingPreference::Subgroup) => { - if data_stream_type != DataStreamType::StreamHeaderSubgroup { - let msg = std::format!( - "uni send stream's data stream type is wrong (expected Subgroup, but got {:?})", - data_stream_type - ); - senders - .close_session_tx() - .send(( - u8::from(TerminationErrorCode::InternalError) as u64, - msg.clone(), - )) - .await?; + .await + } + _ => { + let msg = "cache header not matched"; bail!(msg) } - pubsub_relation_manager - .set_downstream_forwarding_preference( - downstream_session_id, - downstream_subscribe_id, - ForwardingPreference::Subgroup, - ) - .await?; - } - _ => { - let msg = "Invalid forwarding preference"; - senders - .close_session_tx() - .send(( - u8::from(TerminationErrorCode::ProtocolViolation) as u64, - msg.to_string(), - )) - .await?; - bail!(msg) + }; + + if let Err(e) = self.stream.write_all(&message_buf).await { + tracing::warn!("Failed to write to stream: {:?}", e); + bail!(e); } - }; - - // Get the header from the cache storage and send it to the client - match object_cache_storage - .get_header(upstream_session_id, upstream_subscribe_id) - .await? - { - CacheHeader::Track(header) => { - let mut buf = BytesMut::new(); - let header = StreamHeaderTrack::new( - downstream_subscribe_id, - downstream_track_alias, - header.publisher_priority(), - ) - .unwrap(); - header.packetize(&mut buf); + Ok(()) + } - let mut message_buf = BytesMut::with_capacity(buf.len() + 8); - message_buf.extend(write_variable_integer( - u8::from(DataStreamType::StreamHeaderTrack) as u64, - )); - message_buf.extend(buf); + async fn forward_loop( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + subgroup_group_id: &Option, + ) -> Result<()> { + let mut object_cache_id = None; + let mut is_end = false; - if let Err(e) = stream.send_stream.write_all(&message_buf).await { - tracing::warn!("Failed to write to stream: {:?}", e); - bail!(e); + loop { + if is_end { + break; } + + (object_cache_id, is_end) = self + .get_and_forward_object(object_cache_storage, object_cache_id, subgroup_group_id) + .await?; } - CacheHeader::Subgroup(header) => { - let mut buf = BytesMut::new(); - let header = StreamHeaderSubgroup::new( - downstream_subscribe_id, - downstream_track_alias, - header.group_id(), - header.subgroup_id(), - header.publisher_priority(), - ) - .unwrap(); - subgroup_header = Some(header.clone()); + Ok(()) + } + + async fn get_and_forward_object( + &mut self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + cache_id: Option, + subgroup_group_id: &Option, + ) -> Result<(Option, bool)> { + // Do loop until get an object from the cache storage + loop { + let cache = match cache_id { + None => self.try_to_get_first_object(object_cache_storage).await?, + Some(cache_id) => { + self.try_to_get_subsequent_object(object_cache_storage, cache_id) + .await? + } + }; - header.packetize(&mut buf); + match cache { + None => { + // If there is no object in the cache storage, sleep for a while and try again + thread::sleep(self.sleep_time); + continue; + } + Some((cache_id, cache_object)) => { + self.packetize_and_forward_object(&cache_object).await?; - let mut message_buf = BytesMut::with_capacity(buf.len() + 8); - message_buf.extend(write_variable_integer( - u8::from(DataStreamType::StreamHeaderSubgroup) as u64, - )); - message_buf.extend(buf); + let is_end = self + .judge_end_of_forwarding(&cache_object, subgroup_group_id) + .await?; - if let Err(e) = stream.send_stream.write_all(&message_buf).await { - tracing::warn!("Failed to write to stream: {:?}", e); - bail!(e); + return Ok((Some(cache_id), is_end)); + } } } - _ => { - let msg = "cache header not matched"; - senders - .close_session_tx() - .send(( - u8::from(TerminationErrorCode::ProtocolViolation) as u64, - msg.to_string(), - )) - .await?; - bail!(msg) - } } - let mut object_cache_id: Option = None; + async fn try_to_get_first_object( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + ) -> Result> { + let filter_type = self.downstream_subscription.get_filter_type(); + let upstream_session_id = self.object_cache_key.session_id(); + let upstream_subscribe_id = self.object_cache_key.subscribe_id(); - while object_cache_id.is_none() { - // Get the first object from the cache storage - let result = match filter_type { + match filter_type { FilterType::LatestGroup => { object_cache_storage .get_latest_group(upstream_session_id, upstream_subscribe_id) @@ -205,6 +300,8 @@ pub(crate) async fn forward_object_stream( .await } FilterType::AbsoluteStart | FilterType::AbsoluteRange => { + let (start_group, start_object) = self.downstream_subscription.get_absolute_start(); + object_cache_storage .get_absolute_object( upstream_session_id, @@ -214,147 +311,212 @@ pub(crate) async fn forward_object_stream( ) .await } + } + } + + async fn try_to_get_subsequent_object( + &self, + object_cache_storage: &mut ObjectCacheStorageWrapper, + object_cache_id: usize, + ) -> Result> { + let upstream_session_id = self.object_cache_key.session_id(); + let upstream_subscribe_id = self.object_cache_key.subscribe_id(); + + object_cache_storage + .get_next_object(upstream_session_id, upstream_subscribe_id, object_cache_id) + .await + } + + async fn packetize_and_forward_object(&mut self, cache_object: &CacheObject) -> Result<()> { + let message_buf = match cache_object { + CacheObject::Track(object_stream_track) => { + self.packetize_forwarding_object_stream_track(object_stream_track) + .await + } + CacheObject::Subgroup(object_stream_subgroup) => { + self.packetize_forwarding_object_stream_subgroup(object_stream_subgroup) + .await + } + _ => { + let msg = "cache object not matched"; + bail!(msg) + } }; - object_cache_id = match result { - // Send the object to the client if the first cache is exist as track - Ok(Some((id, CacheObject::Track(object)))) => { - let mut buf = BytesMut::new(); - object.packetize(&mut buf); + if let Err(e) = self.stream.write_all(&message_buf).await { + tracing::warn!("Failed to write to stream: {:?}", e); + bail!(e); + } - let mut message_buf = BytesMut::with_capacity(buf.len()); - message_buf.extend(buf); + Ok(()) + } - if let Err(e) = stream.send_stream.write_all(&message_buf).await { - tracing::warn!("Failed to write to stream: {:?}", e); - bail!(e); - } + async fn judge_end_of_forwarding( + &self, + cache_object: &CacheObject, + subgroup_group_id: &Option, + ) -> Result { + let is_end_of_data_stream = self.judge_end_of_data_stream(cache_object).await?; + if is_end_of_data_stream { + return Ok(true); + } - Some(id) + let filter_type = self.downstream_subscription.get_filter_type(); + if filter_type == FilterType::AbsoluteRange { + let is_end_of_absolute_range = self + .judge_end_of_absolute_range(cache_object, subgroup_group_id) + .await?; + if is_end_of_absolute_range { + return Ok(true); } - // Send the object to the client if the first cache is exist as subgroup - Ok(Some((id, CacheObject::Subgroup(object)))) => { - let mut buf = BytesMut::new(); - object.packetize(&mut buf); - - let mut message_buf = BytesMut::with_capacity(buf.len()); - message_buf.extend(buf); + } - if let Err(e) = stream.send_stream.write_all(&message_buf).await { - tracing::warn!("Failed to write to stream: {:?}", e); - bail!(e); - } + Ok(false) + } - Some(id) + async fn judge_end_of_data_stream(&self, cache_object: &CacheObject) -> Result { + let is_end = match cache_object { + CacheObject::Track(object_stream_track) => { + matches!( + object_stream_track.object_status(), + Some(ObjectStatus::EndOfTrackAndGroup) + ) } - // Will be retried if the first cache is not exist - Ok(None) => { - thread::sleep(sleep_time); - object_cache_id + CacheObject::Subgroup(object_stream_subgroup) => { + matches!( + object_stream_subgroup.object_status(), + Some(ObjectStatus::EndOfSubgroup) + | Some(ObjectStatus::EndOfGroup) + | Some(ObjectStatus::EndOfTrackAndGroup) + ) } _ => { - let msg = "cache is not exist"; - senders - .close_session_tx() - .send(( - u8::from(TerminationErrorCode::InternalError) as u64, - msg.to_string(), - )) - .await?; - break; + let msg = "cache object not matched"; + bail!(msg) } }; + + Ok(is_end) } - loop { - // Get the next object from the cache storage - object_cache_id = match object_cache_storage - .get_next_object( - upstream_session_id, - upstream_subscribe_id, - object_cache_id.unwrap(), - ) - .await - { - // Send the object to the client if the next cache is exist as track - Ok(Some((id, CacheObject::Track(object)))) => { - let mut buf = BytesMut::new(); - object.packetize(&mut buf); + async fn judge_end_of_absolute_range( + &self, + cache_object: &CacheObject, + subgroup_group_id: &Option, + ) -> Result { + let (end_group, end_object) = self.downstream_subscription.get_absolute_end(); + let end_group = end_group.unwrap(); + let end_object = end_object.unwrap(); - let mut message_buf = BytesMut::with_capacity(buf.len()); - message_buf.extend(buf); + let is_end = match cache_object { + CacheObject::Track(object_stream_track) => { + let is_group_end = object_stream_track.group_id() == end_group; + let is_object_end = object_stream_track.object_id() == end_object; + let is_ending = is_group_end && is_object_end; - if let Err(e) = stream.send_stream.write_all(&message_buf).await { - tracing::warn!("Failed to write to stream: {:?}", e); - bail!(e); - } + let is_ended = object_stream_track.group_id() > end_group; - // Judge whether ids are reached to end of the range if the filter type is AbsoluteRange - if filter_type == FilterType::AbsoluteRange { - let is_end = (object.group_id() == end_group.unwrap() - && object.object_id() == end_object.unwrap()) - || (object.group_id() > end_group.unwrap()); + is_ending || is_ended + } + CacheObject::Subgroup(object_stream_subgroup) => { + let subgroup_group_id = subgroup_group_id.unwrap(); + let is_group_end = subgroup_group_id == end_group; + let is_object_end = object_stream_subgroup.object_id() == end_object; + let is_ending = is_group_end && is_object_end; - if is_end { - break; - } - } + let is_ended = subgroup_group_id > end_group; - Some(id) + is_ending || is_ended + } + _ => { + let msg = "cache object not matched"; + bail!(msg) } - // Send the object to the client if the next cache is exist as subgroup - Ok(Some((id, CacheObject::Subgroup(object)))) => { - let mut buf = BytesMut::new(); - object.packetize(&mut buf); + }; - let mut message_buf = BytesMut::with_capacity(buf.len()); - message_buf.extend(buf); + Ok(is_end) + } - if let Err(e) = stream.send_stream.write_all(&message_buf).await { - tracing::warn!("Failed to write to stream: {:?}", e); - bail!(e); - } + async fn packetize_forwarding_stream_header_track( + &self, + header: &StreamHeaderTrack, + ) -> BytesMut { + let mut buf = BytesMut::new(); + let downstream_subscribe_id = self.stream.subscribe_id(); + let downstream_track_alias = self.downstream_subscription.get_track_alias(); + + let header = StreamHeaderTrack::new( + downstream_subscribe_id, + downstream_track_alias, + header.publisher_priority(), + ) + .unwrap(); - // Judge whether ids are reached to end of the range if the filter type is AbsoluteRange - let header = subgroup_header.as_ref().unwrap(); - if filter_type == FilterType::AbsoluteRange { - let is_end = (header.group_id() == end_group.unwrap() - && object.object_id() == end_object.unwrap()) - || (header.group_id() > end_group.unwrap()); + header.packetize(&mut buf); - if is_end { - break; - } - } + let mut message_buf = BytesMut::with_capacity(buf.len() + 8); + message_buf.extend(write_variable_integer( + u8::from(DataStreamType::StreamHeaderTrack) as u64, + )); + message_buf.extend(buf); - Some(id) - } - // Will be retried if the next cache is not exist - Ok(None) => { - thread::sleep(sleep_time); - object_cache_id - } - _ => { - let msg = "cache is not exist"; - senders - .close_session_tx() - .send(( - u8::from(TerminationErrorCode::InternalError) as u64, - msg.to_string(), - )) - .await?; - break; - } - }; + message_buf + } + + async fn packetize_forwarding_stream_header_subgroup( + &self, + header: &StreamHeaderSubgroup, + subgroup_group_id: &mut Option, + ) -> BytesMut { + let mut buf = BytesMut::new(); + let downstream_subscribe_id = self.stream.subscribe_id(); + let downstream_track_alias = self.downstream_subscription.get_track_alias(); + + let header = StreamHeaderSubgroup::new( + downstream_subscribe_id, + downstream_track_alias, + header.group_id(), + header.subgroup_id(), + header.publisher_priority(), + ) + .unwrap(); + + *subgroup_group_id = Some(header.group_id()); + + header.packetize(&mut buf); + + let mut message_buf = BytesMut::with_capacity(buf.len() + 8); + message_buf.extend(write_variable_integer( + u8::from(DataStreamType::StreamHeaderSubgroup) as u64, + )); + message_buf.extend(buf); + + message_buf + } + + async fn packetize_forwarding_object_stream_track( + &self, + object: &ObjectStreamTrack, + ) -> BytesMut { + let mut buf = BytesMut::new(); + object.packetize(&mut buf); + + let mut message_buf = BytesMut::with_capacity(buf.len()); + message_buf.extend(buf); + + message_buf } - senders - .buffer_tx() - .send(BufferCommand::ReleaseStream { - session_id: downstream_session_id, - stream_id: downstream_stream_id, - }) - .await?; + async fn packetize_forwarding_object_stream_subgroup( + &self, + object: &ObjectStreamSubgroup, + ) -> BytesMut { + let mut buf = BytesMut::new(); + object.packetize(&mut buf); - Ok(()) + let mut message_buf = BytesMut::with_capacity(buf.len()); + message_buf.extend(buf); + + message_buf + } } diff --git a/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/streams.rs b/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/streams.rs index 9dfc839..e6ba0d2 100644 --- a/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/streams.rs +++ b/moqt-server/src/modules/server_processes/stream_and_datagram/uni_directional_stream/streams.rs @@ -1,6 +1,9 @@ use std::sync::Arc; use tokio::sync::Mutex; -use wtransport::{error::StreamReadError, RecvStream, SendStream}; +use wtransport::{ + error::{StreamReadError, StreamWriteError}, + RecvStream, SendStream, +}; pub(crate) struct UniRecvStream { stable_id: usize, @@ -9,7 +12,7 @@ pub(crate) struct UniRecvStream { } impl UniRecvStream { - pub fn new( + pub(crate) fn new( stable_id: usize, stream_id: u64, recv_stream: Arc>, @@ -21,15 +24,18 @@ impl UniRecvStream { } } - pub fn stable_id(&self) -> usize { + pub(crate) fn stable_id(&self) -> usize { self.stable_id } - pub fn stream_id(&self) -> u64 { + pub(crate) fn stream_id(&self) -> u64 { self.stream_id } - pub async fn read(&mut self, buffer: &mut [u8]) -> Result, StreamReadError> { + pub(crate) async fn read( + &mut self, + buffer: &mut [u8], + ) -> Result, StreamReadError> { let mut recv_stream = self.recv_stream.lock().await; recv_stream.read(buffer).await } @@ -39,7 +45,7 @@ pub(crate) struct UniSendStream { stable_id: usize, stream_id: u64, subscribe_id: u64, - pub(crate) send_stream: SendStream, + send_stream: SendStream, } impl UniSendStream { @@ -57,15 +63,19 @@ impl UniSendStream { } } - pub fn stable_id(&self) -> usize { + pub(crate) fn stable_id(&self) -> usize { self.stable_id } - pub fn stream_id(&self) -> u64 { + pub(crate) fn stream_id(&self) -> u64 { self.stream_id } - pub fn subscribe_id(&self) -> u64 { + pub(crate) fn subscribe_id(&self) -> u64 { self.subscribe_id } + + pub(crate) async fn write_all(&mut self, buffer: &[u8]) -> Result<(), StreamWriteError> { + self.send_stream.write_all(buffer).await + } } diff --git a/moqt-server/src/modules/server_processes/thread_starters.rs b/moqt-server/src/modules/server_processes/thread_starters.rs index 9eb216e..509bba0 100644 --- a/moqt-server/src/modules/server_processes/thread_starters.rs +++ b/moqt-server/src/modules/server_processes/thread_starters.rs @@ -2,8 +2,9 @@ use super::stream_and_datagram::{ bi_directional_stream::{ forwarder::forward_control_message, receiver::handle_bi_recv_stream, stream::BiStream, }, + datagram::{forwarder::DatagramForwarder, receiver::DatagramReceiver}, uni_directional_stream::{ - forwarder::forward_object_stream, + forwarder::ObjectStreamForwarder, receiver::UniStreamReceiver, streams::{UniRecvStream, UniSendStream}, }, @@ -18,7 +19,7 @@ use moqt_core::{ use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; use tracing::{self, Instrument}; -use wtransport::{Connection, RecvStream, SendStream}; +use wtransport::{datagram::Datagram, Connection, RecvStream, SendStream}; async fn spawn_bi_stream_threads( client: Arc>, @@ -146,19 +147,121 @@ async fn spawn_uni_send_stream_thread( tokio::spawn( async move { - let mut stream = UniSendStream::new(stable_id, stream_id, subscribe_id, send_stream); - forward_object_stream(&mut stream, client, data_stream_type) + let stream = UniSendStream::new(stable_id, stream_id, subscribe_id, send_stream); + let senders = client.lock().await.senders(); + + let mut object_stream_forwarder = + ObjectStreamForwarder::init(stream, client, data_stream_type) + .await + .unwrap(); + + match object_stream_forwarder + .start() + .instrument(session_span) + .await + { + Ok(_) => {} + Err(e) => { + let code = TerminationErrorCode::InternalError; + let reason = format!("ObjectStreamForwarder: {:?}", e); + + tracing::error!(reason); + + let _ = senders + .close_session_tx() + .send((u8::from(code) as u64, reason.to_string())) + .await; + } + } + + let _ = object_stream_forwarder.terminate().await; + } + .in_current_span(), + ); + Ok(()) +} + +async fn spawn_recv_datagram_thread( + client: Arc>, + datagram: Datagram, +) -> Result<()> { + let stable_id = client.lock().await.id(); + let session_span = tracing::info_span!("Session", stable_id); + session_span.in_scope(|| { + tracing::info!("Accepted Datagram"); + }); + + // No loop: End after receiving once + tokio::spawn( + async move { + let senders = client.lock().await.senders(); + let mut datagram_receiver = DatagramReceiver::init(client).await; + + match datagram_receiver + .receive(datagram) .instrument(session_span) .await + { + Ok(_) => {} + Err((code, reason)) => { + tracing::error!(reason); + + let _ = senders + .close_session_tx() + .send((u8::from(code) as u64, reason.to_string())) + .await; + } + } + } + .in_current_span(), + ); + Ok(()) +} + +async fn spawn_send_datagram_thread( + client: Arc>, + session: Arc, + subscribe_id: u64, +) -> Result<()> { + let stable_id = client.lock().await.id(); + let session_span = tracing::info_span!("Session", stable_id); + session_span.in_scope(|| { + tracing::info!("Accepted Datagram"); + }); + + tokio::spawn( + async move { + let senders = client.lock().await.senders(); + let mut datagram_forwarder = DatagramForwarder::init(session, subscribe_id, client) + .await + .unwrap(); + + match datagram_forwarder.start().instrument(session_span).await { + Ok(_) => {} + Err(e) => { + let code = TerminationErrorCode::InternalError; + let reason = format!("DatagramForwarder: {:?}", e); + + tracing::error!(reason); + + let _ = senders + .close_session_tx() + .send((u8::from(code) as u64, reason.to_string())) + .await; + } + } + + let _ = datagram_forwarder.terminate().await; } .in_current_span(), ); + Ok(()) } pub(crate) async fn select_spawn_thread( client: &Arc>, - session: &Connection, + session: Arc, open_downstream_stream_or_datagram_rx: &mut mpsc::Receiver<(u64, DataStreamType)>, close_session_rx: &mut mpsc::Receiver<(u64, String)>, is_control_stream_opened: &mut bool, @@ -173,6 +276,10 @@ pub(crate) async fn select_spawn_thread( let recv_stream = stream?; spawn_uni_recv_stream_thread(client.clone(), recv_stream).await?; }, + datagram = session.receive_datagram() => { + let datagram = datagram?; + spawn_recv_datagram_thread(client.clone(), datagram).await?; + }, // Waiting for a uni-directional send stream open request and forwarding the message Some((subscribe_id, data_stream_type)) = open_downstream_stream_or_datagram_rx.recv() => { match data_stream_type { @@ -181,8 +288,9 @@ pub(crate) async fn select_spawn_thread( spawn_uni_send_stream_thread(client.clone(), send_stream, subscribe_id, data_stream_type).await?; } DataStreamType::ObjectDatagram => { - // TODO: Open datagram thread - unimplemented!(); + let session = session.clone(); + spawn_send_datagram_thread(client.clone(), session, subscribe_id).await?; + } } },