Skip to content

Commit

Permalink
fix: poll_read: wrong behavior on half close (#362)
Browse files Browse the repository at this point in the history
* fix: poll_read: wrong behavior on half close

* enhance naming
  • Loading branch information
umiro authored Aug 31, 2022
1 parent 19907a9 commit dceb87c
Showing 1 changed file with 113 additions and 37 deletions.
150 changes: 113 additions & 37 deletions yamux/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ impl StreamHandle {
if flags.contains(Flag::Ack) && self.state == StreamState::SynSent {
self.state = StreamState::SynReceived;
}
let mut close_stream = false;
if flags.contains(Flag::Fin) {
match self.state {
StreamState::Init
Expand All @@ -186,18 +185,12 @@ impl StreamHandle {
}
StreamState::LocalClosing => {
self.state = StreamState::Closed;
close_stream = true;
}
_ => return Err(Error::UnexpectedFlag),
}
}
if flags.contains(Flag::Rst) {
self.state = StreamState::Reset;
close_stream = true;
}

if close_stream {
self.close()?;
}
Ok(())
}
Expand Down Expand Up @@ -308,25 +301,24 @@ impl StreamHandle {
Ok(())
}

fn check_self_state(&mut self) -> Result<(), io::Error> {
// Returns Ok(true) only if eof is reached.
fn check_self_state(&mut self) -> io::Result<bool> {
// if read buf is empty and state is close, return close error
if self.read_buf.is_empty() {
match self.state {
StreamState::RemoteClosing => {
StreamState::RemoteClosing | StreamState::Closed => {
debug!("closed(EOF)");
let _ignore = self.send_close();
Err(io::ErrorKind::UnexpectedEof.into())
// an empty read indicates that EOF is reached.
Ok(true)
}
StreamState::Reset => {
debug!("connection reset");
let _ignore = self.send_close();
Err(io::ErrorKind::ConnectionReset.into())
}
StreamState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
_ => Ok(()),
_ => Ok(false),
}
} else {
Ok(())
Ok(false)
}
}
}
Expand All @@ -337,7 +329,9 @@ impl AsyncRead for StreamHandle {
cx: &mut Context,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.check_self_state()?;
if self.check_self_state()? {
return Poll::Ready(Ok(()));
}

if let Err(e) = self.recv_frames(cx) {
match e {
Expand All @@ -350,7 +344,9 @@ impl AsyncRead for StreamHandle {
}
}

self.check_self_state()?;
if self.check_self_state()? {
return Poll::Ready(Ok(()));
}

let n = ::std::cmp::min(buf.remaining(), self.read_buf.len());
trace!(
Expand All @@ -367,11 +363,8 @@ impl AsyncRead for StreamHandle {

buf.put_slice(&b);
match self.state {
StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => (),
StreamState::LocalClosing => {
if self.close().is_err() {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
}
StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => {
debug!("this branch should be unreachable")
}
_ => {
if self.send_window_update().is_err() {
Expand All @@ -391,9 +384,7 @@ impl AsyncWrite for StreamHandle {
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.state {
StreamState::RemoteClosing | StreamState::Reset => {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
}
StreamState::Reset => return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
StreamState::LocalClosing | StreamState::Closed => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
Expand Down Expand Up @@ -446,18 +437,27 @@ impl AsyncWrite for StreamHandle {
impl Drop for StreamHandle {
fn drop(&mut self) {
if !self.unbound_event_sender.is_closed() && self.state != StreamState::Closed {
let event = StreamEvent::Closed(self.id);
// LocalClosing means that local have sent Fin to the remote and waiting for a response.
// if not, we should send Rst first
if self.state != StreamState::LocalClosing {
let mut flags = self.get_flags();
flags.add(Flag::Rst);
let frame = Frame::new_window_update(flags, self.id, 0);
let rst_event = StreamEvent::Frame(frame);

// Always successful unless the session is dropped
let _ignore = self.unbound_event_sender.unbounded_send(rst_event);
match self.state {
// LocalClosing means that local have sent Fin to the remote and waiting for a response.
StreamState::LocalClosing | StreamState::Reset => (),
// if not, we should send Rst first
StreamState::Established
| StreamState::Init
| StreamState::RemoteClosing
| StreamState::SynReceived
| StreamState::SynSent => {
let mut flags = self.get_flags();
flags.add(Flag::Rst);
let frame = Frame::new_window_update(flags, self.id, 0);
let rst_event = StreamEvent::Frame(frame);

// Always successful unless the session is dropped
let _ignore = self.unbound_event_sender.unbounded_send(rst_event);
}
StreamState::Closed => unreachable!(),
}

let event = StreamEvent::Closed(self.id);
let _ignore = self.unbound_event_sender.unbounded_send(event);
}
}
Expand Down Expand Up @@ -560,10 +560,13 @@ mod test {
// try poll stream handle, then it will recv RST frame and set self state to reset
assert_eq!(
stream.read(&mut b).await.unwrap_err().kind(),
ErrorKind::BrokenPipe
ErrorKind::ConnectionReset
);

assert_eq!(stream.state, StreamState::Reset);

drop(stream);

let event = unbound_receiver.next().await.unwrap();
match event {
StreamEvent::Closed(_) => (),
Expand Down Expand Up @@ -681,4 +684,77 @@ mod test {
}
});
}

#[test]
fn test_read_with_half_close() {
let rt = rt();
rt.block_on(async {
let (mut frame_sender, frame_receiver) = channel(2);
let (unbound_sender, _unbound_receiver) = unbounded();
let mut stream = StreamHandle::new(
0,
unbound_sender,
frame_receiver,
StreamState::Init,
INITIAL_STREAM_WINDOW,
);

stream.shutdown().await.unwrap();

assert_eq!(stream.state, StreamState::LocalClosing);

let flags = Flags::from(Flag::Syn);
let frame = Frame::new_data(flags, 0, BytesMut::from("1234"));
frame_sender.send(frame).await.unwrap();
let mut b = [0; 1024];

assert_eq!(stream.read(&mut b).await.unwrap(), 4);
assert_eq!(&b[..4], b"1234");

assert_eq!(stream.state, StreamState::LocalClosing);
});
}

#[test]
fn test_write_with_half_close() {
let rt = rt();
rt.block_on(async {
let (mut frame_sender, frame_receiver) = channel(2);
let (unbound_sender, mut unbound_receiver) = unbounded();
let mut stream = StreamHandle::new(
0,
unbound_sender,
frame_receiver,
StreamState::Init,
INITIAL_STREAM_WINDOW,
);

let flags = Flags::from(Flag::Fin);
let frame = Frame::new_window_update(flags, 0, 0);
frame_sender.send(frame).await.unwrap();
let mut b = [0; 1024];

assert_eq!(stream.read(&mut b).await.unwrap(), 0);
assert_eq!(stream.state, StreamState::RemoteClosing);

const TEXT: &[u8] = b"testtext";

let jh = tokio::spawn(tokio::time::timeout(std::time::Duration::from_secs(4), async move {
loop {
match unbound_receiver.try_next() {
Ok(Some(ref event)) if matches!(event, StreamEvent::Frame(frame) if frame.length() == TEXT.len() as u32) => break,
Err(_) => (),
_ => panic!("must be frame with written text"),
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
}));

stream.write_all(TEXT).await.unwrap();

jh.await.unwrap().expect("not tiemout");

assert_eq!(stream.state, StreamState::RemoteClosing);
});
}
}

0 comments on commit dceb87c

Please sign in to comment.