From 4e13bbe8a4f55a2278d1f93c75c6d1cbd25927aa Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Tue, 6 Aug 2024 09:24:20 +0200 Subject: [PATCH] Do not decrypt already received packets --- src/session.rs | 12 +++++++ src/streams/receive.rs | 11 +++++++ src/streams/register.rs | 4 +++ src/streams/register_nack.rs | 61 ++++++++++++++++++++++++++++-------- 4 files changed, 75 insertions(+), 13 deletions(-) diff --git a/src/session.rs b/src/session.rs index db6a3f51..603613b6 100644 --- a/src/session.rs +++ b/src/session.rs @@ -388,6 +388,18 @@ impl Session { // Either way we get a seq_no_outer which is used to decrypt the SRTP. let mut seq_no = stream.extend_seq(&header, is_repair, max_seq_lookup); + if !stream.is_new_packet(is_repair, seq_no) { + // Dupe packet. This could be a potential SRTP replay attack, which means + // we should not spend any CPU cycles towards decrypting it. + trace!( + "Ignoring dupe packet mid: {} seq_no: {} is_repair: {}", + mid, + seq_no, + is_repair + ); + return; + } + let mut data = match srtp.unprotect_rtp(buf, &header, *seq_no) { Some(v) => v, None => { diff --git a/src/streams/receive.rs b/src/streams/receive.rs index dbf3e840..9d6ee133 100644 --- a/src/streams/receive.rs +++ b/src/streams/receive.rs @@ -346,6 +346,17 @@ impl StreamRx { } } + pub(crate) fn is_new_packet(&self, is_repair: bool, seq_no: SeqNo) -> bool { + let register_ref = if is_repair { + self.register_rtx.as_ref() + } else { + self.register.as_ref() + }; + + // Unwrap is OK because we always call extend_seq() for the same is_repair flag beforehand + register_ref.unwrap().accepts(seq_no) + } + pub(crate) fn update_register( &mut self, now: Instant, diff --git a/src/streams/register.rs b/src/streams/register.rs index 530184df..65134b1a 100644 --- a/src/streams/register.rs +++ b/src/streams/register.rs @@ -73,6 +73,10 @@ impl ReceiverRegister { } } + pub fn accepts(&self, seq: SeqNo) -> bool { + self.nack.accepts(seq) + } + pub fn update(&mut self, seq: SeqNo, arrival: Instant, rtp_time: u32, clock_rate: u32) -> bool { if self.first.is_none() { self.first = Some(seq); diff --git a/src/streams/register_nack.rs b/src/streams/register_nack.rs index 0b55b16c..3e29ac05 100644 --- a/src/streams/register_nack.rs +++ b/src/streams/register_nack.rs @@ -55,20 +55,21 @@ impl<'a> Iterator for NackIterator<'a> { type Item = NackEntry; fn next(&mut self) -> Option { - self.next = (self.next..=self.end).find(|s| self.reg.packet((*s).into()).needs_nack())?; + self.next = + (self.next..=self.end).find(|s| self.reg.packet_mut((*s).into()).needs_nack())?; let mut entry = NackEntry { pid: (self.next % U16_MAX) as u16, blp: 0, }; - self.reg.packet(self.next.into()).nack_count += 1; + self.reg.packet_mut(self.next.into()).nack_count += 1; self.next += 1; for (i, s) in (self.next..self.end).take(16).enumerate() { - let packet = self.reg.packet(s.into()); + let packet = self.reg.packet_mut(s.into()); if packet.needs_nack() { - self.reg.packet(self.next.into()).nack_count += 1; + self.reg.packet_mut(self.next.into()).nack_count += 1; entry.blp |= 1 << i } self.next += 1; @@ -97,6 +98,20 @@ impl NackRegister { n } + pub fn accepts(&self, seq: SeqNo) -> bool { + let Some(active) = self.active.clone() else { + // if we don't have initialized, we do want the first packet. + return true; + }; + + // behind the window + if seq < active.start { + return false; + } + + !self.packet(seq).received || seq > active.end + } + pub fn update(&mut self, seq: SeqNo) -> bool { let Some(active) = self.active.clone() else { // automatically pick up the first seq number @@ -109,7 +124,7 @@ impl NackRegister { return false; } - let new = !self.packet(seq).received || seq > active.end; + let new = !self.packet_mut(seq).received || seq > active.end; let end = active.end.max(seq); @@ -117,7 +132,7 @@ impl NackRegister { let min = end.saturating_sub(MAX_MISORDER); let mut start = (*active.start).max(min); while start < *end { - if !self.packet(start.into()).received && start != *seq { + if !self.packet_mut(start.into()).received && start != *seq { break; } start += 1; @@ -127,11 +142,11 @@ impl NackRegister { // reset packets that are rolling our of the nack window for (i, s) in (*active.start..*start).enumerate() { - let p = self.packet(s.into()); + let p = self.packet_mut(s.into()); if !p.received && s != *seq { debug!("Seq no {} missing after {} attempts", s, p.nack_count); } - self.packet(s.into()).reset(); + self.packet_mut(s.into()).reset(); if i > self.packets.len() { // we have reset all entries already @@ -140,7 +155,7 @@ impl NackRegister { } if (start..=end).contains(&seq) { - self.packet(seq).mark_received(); + self.packet_mut(seq).mark_received(); } self.active = Some(start..end); @@ -150,7 +165,7 @@ impl NackRegister { fn init_with_seq(&mut self, seq: SeqNo) { self.active = Some(seq..seq); - self.packet(seq).mark_received(); + self.packet_mut(seq).mark_received(); } pub fn max_seq(&self) -> Option { @@ -162,7 +177,7 @@ impl NackRegister { /// This modifies the state as it counts how many times packets have been nacked pub fn nack_reports(&mut self) -> Option> { let Range { start, end } = self.active.clone()?; - let start = (*start..=*end).find(|s| self.packet((*s).into()).needs_nack())?; + let start = (*start..=*end).find(|s| self.packet_mut((*s).into()).needs_nack())?; Some( ReportList::lists_from_iter(NackIterator { @@ -185,7 +200,12 @@ impl NackRegister { (*seq % self.packets.len() as u64) as usize } - fn packet(&mut self, seq: SeqNo) -> &mut PacketStatus { + fn packet(&self, seq: SeqNo) -> &PacketStatus { + let index = self.as_index(seq); + &self.packets[index] + } + + fn packet_mut(&mut self, seq: SeqNo) -> &mut PacketStatus { let index = self.as_index(seq); &mut self.packets[index] } @@ -215,7 +235,7 @@ mod test { ); let active = reg.active.clone().expect("nack range"); assert_eq!( - reg.packet(seq.into()).received, + reg.packet_mut(seq.into()).received, expect_received, "seq {} expected to{} be received in {:?}", seq, @@ -250,34 +270,49 @@ mod test { fn active_window_sliding() { let mut reg = NackRegister::new(None); + assert!(reg.accepts(10.into())); assert_update(&mut reg, 10, true, true, 10..10); // packet before window start is ignored + assert!(!reg.accepts(9.into())); assert_update(&mut reg, 9, false, false, 10..10); // duped packet + assert!(!reg.accepts(10.into())); assert_update(&mut reg, 10, false, true, 10..10); // future packets accepted, window not sliding let next = 10 + MAX_MISORDER; + assert!(reg.accepts(next.into())); assert_update(&mut reg, next, true, true, 11..next); let next = 11 + MAX_MISORDER; + assert!(reg.accepts(next.into())); assert_update(&mut reg, next, true, true, 11..next); // future packet accepted, sliding window let next = 12 + MAX_MISORDER; + assert!(reg.accepts(next.into())); assert_update(&mut reg, next, true, true, 12..next); // older packet received within window let next = 13; + assert!(reg.accepts(next.into())); assert_update(&mut reg, next, true, true, 12..(12 + MAX_MISORDER)); + // do not want the same packet again + assert!(!reg.accepts(next.into())); + // future packet accepted, sliding window start skips over received let next = 13 + MAX_MISORDER; + assert!(reg.accepts(next.into())); assert_update(&mut reg, next, true, true, 14..next); + // do not want the same packet again + assert!(!reg.accepts(next.into())); + // older packet accepted, window star moves ahead let next = 14; + assert!(reg.accepts(next.into())); assert_update(&mut reg, next, true, false, 15..(13 + MAX_MISORDER)); }