From 7cd485e19cb54e49fb2b8ede2e0e93cd483a21fc Mon Sep 17 00:00:00 2001 From: Adam McQuilkin <46639306+ajmcquilkin@users.noreply.github.com> Date: Fri, 15 Mar 2024 13:49:46 -0700 Subject: [PATCH] Refactored stream API to make malformed packet detection more reliable --- src/connections/stream_buffer.rs | 148 ++++++++++++++++++++++++------- 1 file changed, 115 insertions(+), 33 deletions(-) diff --git a/src/connections/stream_buffer.rs b/src/connections/stream_buffer.rs index cb97976..e36e599 100644 --- a/src/connections/stream_buffer.rs +++ b/src/connections/stream_buffer.rs @@ -42,6 +42,8 @@ pub enum StreamBufferError { DecodeFailure(#[from] prost::DecodeError), } +const PACKET_HEADER_SIZE: usize = 4; + impl StreamBuffer { /// Creates a new StreamBuffer instance that will send decoded FromRadio packets /// to the given broadcast channel. @@ -111,11 +113,11 @@ impl StreamBuffer { trace!("Packet buffer: {:?}", self.buffer); // Check that the buffer can potentially contain a packet header - if self.buffer.len() < 4 { + if self.buffer.len() < PACKET_HEADER_SIZE { debug!("Buffer data is shorter than packet header size, failing"); return Err(StreamBufferError::IncompletePacket { buffer_size: self.buffer.len(), - packet_size: 4, + packet_size: PACKET_HEADER_SIZE, }); } @@ -135,9 +137,45 @@ impl StreamBuffer { ); self.buffer = self.buffer[framing_index..].to_vec(); + + log::trace!("Buffer after shifting: {:?}", self.buffer); + framing_index = self.get_framing_index()?; } + // Note: the framing index should always be 0 at this point, keeping for clarity + let incoming_packet_data_size = self.extract_data_size_from_header(framing_index)?; + + self.validate_packet_in_buffer(incoming_packet_data_size, framing_index)?; + + // Get packet data, excluding magic bytes + let packet_data = + self.extract_packet_from_buffer(incoming_packet_data_size, framing_index)?; + + // Attempt to decode the current packet + let decoded_packet = protobufs::FromRadio::decode(packet_data.as_slice())?; + + Ok(decoded_packet) + } + + // All valid packets start with the sequence [0x94 0xc3 size_msb size_lsb], where + // size_msb and size_lsb collectively give the size of the incoming packet + // Note that the maximum packet size currently stands at 240 bytes, meaning an MSB is not needed + fn get_framing_index(&mut self) -> Result { + match self.buffer.iter().position(|&b| b == 0x94) { + Some(idx) => Ok(idx), + None => { + warn!("Could not find index of 0x94, purging buffer"); + self.buffer.clear(); // Clear buffer since no packets exist + Err(StreamBufferError::MissingHeaderByte) + } + } + } + + fn extract_data_size_from_header( + &self, + framing_index: usize, + ) -> Result { // Get the "framing byte" after the start of the packet header, or fail if not found let framing_byte = match self.buffer.get(framing_index + 1) { Some(val) => val, @@ -145,7 +183,7 @@ impl StreamBuffer { debug!("Could not find framing byte, waiting for more data"); return Err(StreamBufferError::IncompletePacket { buffer_size: self.buffer.len(), - packet_size: 4, + packet_size: PACKET_HEADER_SIZE, }); } }; @@ -180,26 +218,47 @@ impl StreamBuffer { // Combine MSB and LSB of incoming packet size bytes // Recall that packet size doesn't include the first four magic bytes - let incoming_packet_size: usize = usize::from(4 + u16::from_le_bytes([*lsb, *msb])); + let incoming_packet_data_size: usize = usize::from(u16::from_le_bytes([*lsb, *msb])); + + return Ok(incoming_packet_data_size); + } - // Defer decoding until the correct number of bytes are received - if self.buffer.len() < incoming_packet_size { - warn!("Stream buffer size is less than size of packet"); + fn validate_packet_in_buffer( + &mut self, + packet_data_size: usize, + framing_index: usize, + ) -> Result<(), StreamBufferError> { + if self.buffer.len() < PACKET_HEADER_SIZE + packet_data_size { return Err(StreamBufferError::IncompletePacket { buffer_size: self.buffer.len(), - packet_size: incoming_packet_size, + packet_size: packet_data_size, }); } - // Get packet data, excluding magic bytes - let packet: Vec = self.buffer[4..incoming_packet_size].to_vec(); - - // Packet is malformed if the start of another packet occurs within the - // defined limits of the current packet - let malformed_packet_detector_index = packet.iter().position(|&b| b == 0x94); + let packet_data_start_index = framing_index + PACKET_HEADER_SIZE; + + trace!( + "Validating bytes in range [{}, {})", + packet_data_start_index, + packet_data_start_index + packet_data_size + ); + + // Packet is malformed if the start of another packet occurs within the defined limits of the current packet + let malformed_packet_detector_index = self + .buffer + .iter() + .enumerate() + // Only want to check within the range of the current packet's data (not header) + .filter(|&(i, _)| { + packet_data_start_index <= i && i < packet_data_start_index + packet_data_size + }) + .position(|(_, b)| *b == 0x94) + // `position` returns the index from the filtered array, need to re-normalize to the original buffer + .map(|idx| idx + packet_data_start_index); let malformed_packet_detector_byte = if let Some(index) = malformed_packet_detector_index { - packet.get(index + 1) + trace!("Found 0x94 at index {}", index); + self.buffer.get(index + 1) } else { None }; @@ -220,27 +279,45 @@ impl StreamBuffer { }); } - // Remove current packet from buffer based on start location of next packet - self.buffer = self.buffer[incoming_packet_size..].to_vec(); - - // Attempt to decode the current packet - let decoded_packet = protobufs::FromRadio::decode(packet.as_slice())?; - - Ok(decoded_packet) + Ok(()) } - // All valid packets start with the sequence [0x94 0xc3 size_msb size_lsb], where - // size_msb and size_lsb collectively give the size of the incoming packet - // Note that the maximum packet size currently stands at 240 bytes, meaning an MSB is not needed - fn get_framing_index(&mut self) -> Result { - match self.buffer.iter().position(|&b| b == 0x94) { - Some(idx) => Ok(idx), - None => { - warn!("Could not find index of 0x94, purging buffer"); - self.buffer.clear(); // Clear buffer since no packets exist - Err(StreamBufferError::MissingHeaderByte) - } + fn extract_packet_from_buffer( + &mut self, + packet_data_size: usize, + framing_index: usize, + ) -> Result, StreamBufferError> { + if self.buffer.len() < packet_data_size { + return Err(StreamBufferError::IncompletePacket { + buffer_size: self.buffer.len(), + packet_size: packet_data_size, + }); } + + let packet_size = PACKET_HEADER_SIZE + packet_data_size; + + // Extract packet with header before removing header + let mut packet_data_with_header: Vec = + self.buffer.drain(framing_index..packet_size).collect(); + + trace!( + "Extracted packet data with header of length {:?} from buffer: {:?}", + packet_data_with_header.len(), + packet_data_with_header + ); + + // Remove header bytes + let packet_data: Vec = packet_data_with_header + .drain(PACKET_HEADER_SIZE..) + .collect(); + + trace!( + "Extracted packet data of length {:?} from buffer: {:?}", + packet_data.len(), + packet_data + ); + + Ok(packet_data) } } @@ -376,4 +453,9 @@ mod tests { assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(valid_packet)); } + + // #[tokio::test] + // async fn should_handle_incomplete_header_at_start_of_buffer() { + // // TODO + // } }