Skip to content

Commit

Permalink
Refactored stream API to make malformed packet detection more reliable
Browse files Browse the repository at this point in the history
  • Loading branch information
ajmcquilkin committed Mar 15, 2024
1 parent c2b4622 commit 7cd485e
Showing 1 changed file with 115 additions and 33 deletions.
148 changes: 115 additions & 33 deletions src/connections/stream_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pub enum StreamBufferError {
DecodeFailure(#[from] prost::DecodeError),
}

const PACKET_HEADER_SIZE: usize = 4;

impl StreamBuffer {
/// Creates a new StreamBuffer instance that will send decoded FromRadio packets
/// to the given broadcast channel.
Expand Down Expand Up @@ -111,11 +113,11 @@ impl StreamBuffer {
trace!("Packet buffer: {:?}", self.buffer);

// Check that the buffer can potentially contain a packet header
if self.buffer.len() < 4 {
if self.buffer.len() < PACKET_HEADER_SIZE {
debug!("Buffer data is shorter than packet header size, failing");
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: 4,
packet_size: PACKET_HEADER_SIZE,
});
}

Expand All @@ -135,17 +137,53 @@ impl StreamBuffer {
);

self.buffer = self.buffer[framing_index..].to_vec();

log::trace!("Buffer after shifting: {:?}", self.buffer);

framing_index = self.get_framing_index()?;
}

// Note: the framing index should always be 0 at this point, keeping for clarity
let incoming_packet_data_size = self.extract_data_size_from_header(framing_index)?;

self.validate_packet_in_buffer(incoming_packet_data_size, framing_index)?;

// Get packet data, excluding magic bytes
let packet_data =
self.extract_packet_from_buffer(incoming_packet_data_size, framing_index)?;

// Attempt to decode the current packet
let decoded_packet = protobufs::FromRadio::decode(packet_data.as_slice())?;

Ok(decoded_packet)
}

// All valid packets start with the sequence [0x94 0xc3 size_msb size_lsb], where
// size_msb and size_lsb collectively give the size of the incoming packet
// Note that the maximum packet size currently stands at 240 bytes, meaning an MSB is not needed
fn get_framing_index(&mut self) -> Result<usize, StreamBufferError> {
match self.buffer.iter().position(|&b| b == 0x94) {
Some(idx) => Ok(idx),
None => {
warn!("Could not find index of 0x94, purging buffer");
self.buffer.clear(); // Clear buffer since no packets exist
Err(StreamBufferError::MissingHeaderByte)
}
}
}

fn extract_data_size_from_header(
&self,
framing_index: usize,
) -> Result<usize, StreamBufferError> {
// Get the "framing byte" after the start of the packet header, or fail if not found
let framing_byte = match self.buffer.get(framing_index + 1) {
Some(val) => val,
None => {
debug!("Could not find framing byte, waiting for more data");
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: 4,
packet_size: PACKET_HEADER_SIZE,
});
}
};
Expand Down Expand Up @@ -180,26 +218,47 @@ impl StreamBuffer {

// Combine MSB and LSB of incoming packet size bytes
// Recall that packet size doesn't include the first four magic bytes
let incoming_packet_size: usize = usize::from(4 + u16::from_le_bytes([*lsb, *msb]));
let incoming_packet_data_size: usize = usize::from(u16::from_le_bytes([*lsb, *msb]));

return Ok(incoming_packet_data_size);
}

// Defer decoding until the correct number of bytes are received
if self.buffer.len() < incoming_packet_size {
warn!("Stream buffer size is less than size of packet");
fn validate_packet_in_buffer(
&mut self,
packet_data_size: usize,
framing_index: usize,
) -> Result<(), StreamBufferError> {
if self.buffer.len() < PACKET_HEADER_SIZE + packet_data_size {
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: incoming_packet_size,
packet_size: packet_data_size,
});
}

// Get packet data, excluding magic bytes
let packet: Vec<u8> = self.buffer[4..incoming_packet_size].to_vec();

// Packet is malformed if the start of another packet occurs within the
// defined limits of the current packet
let malformed_packet_detector_index = packet.iter().position(|&b| b == 0x94);
let packet_data_start_index = framing_index + PACKET_HEADER_SIZE;

trace!(
"Validating bytes in range [{}, {})",
packet_data_start_index,
packet_data_start_index + packet_data_size
);

// Packet is malformed if the start of another packet occurs within the defined limits of the current packet
let malformed_packet_detector_index = self
.buffer
.iter()
.enumerate()
// Only want to check within the range of the current packet's data (not header)
.filter(|&(i, _)| {
packet_data_start_index <= i && i < packet_data_start_index + packet_data_size
})
.position(|(_, b)| *b == 0x94)
// `position` returns the index from the filtered array, need to re-normalize to the original buffer
.map(|idx| idx + packet_data_start_index);

let malformed_packet_detector_byte = if let Some(index) = malformed_packet_detector_index {
packet.get(index + 1)
trace!("Found 0x94 at index {}", index);
self.buffer.get(index + 1)
} else {
None
};
Expand All @@ -220,27 +279,45 @@ impl StreamBuffer {
});
}

// Remove current packet from buffer based on start location of next packet
self.buffer = self.buffer[incoming_packet_size..].to_vec();

// Attempt to decode the current packet
let decoded_packet = protobufs::FromRadio::decode(packet.as_slice())?;

Ok(decoded_packet)
Ok(())
}

// All valid packets start with the sequence [0x94 0xc3 size_msb size_lsb], where
// size_msb and size_lsb collectively give the size of the incoming packet
// Note that the maximum packet size currently stands at 240 bytes, meaning an MSB is not needed
fn get_framing_index(&mut self) -> Result<usize, StreamBufferError> {
match self.buffer.iter().position(|&b| b == 0x94) {
Some(idx) => Ok(idx),
None => {
warn!("Could not find index of 0x94, purging buffer");
self.buffer.clear(); // Clear buffer since no packets exist
Err(StreamBufferError::MissingHeaderByte)
}
fn extract_packet_from_buffer(
&mut self,
packet_data_size: usize,
framing_index: usize,
) -> Result<Vec<u8>, StreamBufferError> {
if self.buffer.len() < packet_data_size {
return Err(StreamBufferError::IncompletePacket {
buffer_size: self.buffer.len(),
packet_size: packet_data_size,
});
}

let packet_size = PACKET_HEADER_SIZE + packet_data_size;

// Extract packet with header before removing header
let mut packet_data_with_header: Vec<u8> =
self.buffer.drain(framing_index..packet_size).collect();

trace!(
"Extracted packet data with header of length {:?} from buffer: {:?}",
packet_data_with_header.len(),
packet_data_with_header
);

// Remove header bytes
let packet_data: Vec<u8> = packet_data_with_header
.drain(PACKET_HEADER_SIZE..)
.collect();

trace!(
"Extracted packet data of length {:?} from buffer: {:?}",
packet_data.len(),
packet_data
);

Ok(packet_data)
}
}

Expand Down Expand Up @@ -376,4 +453,9 @@ mod tests {

assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(valid_packet));
}

// #[tokio::test]
// async fn should_handle_incomplete_header_at_start_of_buffer() {
// // TODO
// }
}

0 comments on commit 7cd485e

Please sign in to comment.