diff --git a/yamux/src/stream.rs b/yamux/src/stream.rs index e75ce679..790ad8dc 100644 --- a/yamux/src/stream.rs +++ b/yamux/src/stream.rs @@ -175,7 +175,6 @@ impl StreamHandle { if flags.contains(Flag::Ack) && self.state == StreamState::SynSent { self.state = StreamState::SynReceived; } - let mut close_stream = false; if flags.contains(Flag::Fin) { match self.state { StreamState::Init @@ -186,18 +185,12 @@ impl StreamHandle { } StreamState::LocalClosing => { self.state = StreamState::Closed; - close_stream = true; } _ => return Err(Error::UnexpectedFlag), } } if flags.contains(Flag::Rst) { self.state = StreamState::Reset; - close_stream = true; - } - - if close_stream { - self.close()?; } Ok(()) } @@ -308,25 +301,24 @@ impl StreamHandle { Ok(()) } - fn check_self_state(&mut self) -> Result<(), io::Error> { + // Returns Ok(true) only if eof is reached. + fn check_self_state(&mut self) -> io::Result { // if read buf is empty and state is close, return close error if self.read_buf.is_empty() { match self.state { - StreamState::RemoteClosing => { + StreamState::RemoteClosing | StreamState::Closed => { debug!("closed(EOF)"); - let _ignore = self.send_close(); - Err(io::ErrorKind::UnexpectedEof.into()) + // an empty read indicates that EOF is reached. + Ok(true) } StreamState::Reset => { debug!("connection reset"); - let _ignore = self.send_close(); Err(io::ErrorKind::ConnectionReset.into()) } - StreamState::Closed => Err(io::ErrorKind::BrokenPipe.into()), - _ => Ok(()), + _ => Ok(false), } } else { - Ok(()) + Ok(false) } } } @@ -337,7 +329,9 @@ impl AsyncRead for StreamHandle { cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.check_self_state()?; + if self.check_self_state()? { + return Poll::Ready(Ok(())); + } if let Err(e) = self.recv_frames(cx) { match e { @@ -350,7 +344,9 @@ impl AsyncRead for StreamHandle { } } - self.check_self_state()?; + if self.check_self_state()? { + return Poll::Ready(Ok(())); + } let n = ::std::cmp::min(buf.remaining(), self.read_buf.len()); trace!( @@ -367,11 +363,8 @@ impl AsyncRead for StreamHandle { buf.put_slice(&b); match self.state { - StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => (), - StreamState::LocalClosing => { - if self.close().is_err() { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())); - } + StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => { + debug!("this branch should be unreachable") } _ => { if self.send_window_update().is_err() { @@ -391,9 +384,7 @@ impl AsyncWrite for StreamHandle { buf: &[u8], ) -> Poll> { match self.state { - StreamState::RemoteClosing | StreamState::Reset => { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) - } + StreamState::Reset => return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())), StreamState::LocalClosing | StreamState::Closed => { return Poll::Ready(Err(io::Error::new( io::ErrorKind::BrokenPipe, @@ -446,18 +437,27 @@ impl AsyncWrite for StreamHandle { impl Drop for StreamHandle { fn drop(&mut self) { if !self.unbound_event_sender.is_closed() && self.state != StreamState::Closed { - let event = StreamEvent::Closed(self.id); - // LocalClosing means that local have sent Fin to the remote and waiting for a response. - // if not, we should send Rst first - if self.state != StreamState::LocalClosing { - let mut flags = self.get_flags(); - flags.add(Flag::Rst); - let frame = Frame::new_window_update(flags, self.id, 0); - let rst_event = StreamEvent::Frame(frame); - - // Always successful unless the session is dropped - let _ignore = self.unbound_event_sender.unbounded_send(rst_event); + match self.state { + // LocalClosing means that local have sent Fin to the remote and waiting for a response. + StreamState::LocalClosing | StreamState::Reset => (), + // if not, we should send Rst first + StreamState::Established + | StreamState::Init + | StreamState::RemoteClosing + | StreamState::SynReceived + | StreamState::SynSent => { + let mut flags = self.get_flags(); + flags.add(Flag::Rst); + let frame = Frame::new_window_update(flags, self.id, 0); + let rst_event = StreamEvent::Frame(frame); + + // Always successful unless the session is dropped + let _ignore = self.unbound_event_sender.unbounded_send(rst_event); + } + StreamState::Closed => unreachable!(), } + + let event = StreamEvent::Closed(self.id); let _ignore = self.unbound_event_sender.unbounded_send(event); } } @@ -560,10 +560,13 @@ mod test { // try poll stream handle, then it will recv RST frame and set self state to reset assert_eq!( stream.read(&mut b).await.unwrap_err().kind(), - ErrorKind::BrokenPipe + ErrorKind::ConnectionReset ); + assert_eq!(stream.state, StreamState::Reset); + drop(stream); + let event = unbound_receiver.next().await.unwrap(); match event { StreamEvent::Closed(_) => (), @@ -681,4 +684,77 @@ mod test { } }); } + + #[test] + fn test_read_with_half_close() { + let rt = rt(); + rt.block_on(async { + let (mut frame_sender, frame_receiver) = channel(2); + let (unbound_sender, _unbound_receiver) = unbounded(); + let mut stream = StreamHandle::new( + 0, + unbound_sender, + frame_receiver, + StreamState::Init, + INITIAL_STREAM_WINDOW, + ); + + stream.shutdown().await.unwrap(); + + assert_eq!(stream.state, StreamState::LocalClosing); + + let flags = Flags::from(Flag::Syn); + let frame = Frame::new_data(flags, 0, BytesMut::from("1234")); + frame_sender.send(frame).await.unwrap(); + let mut b = [0; 1024]; + + assert_eq!(stream.read(&mut b).await.unwrap(), 4); + assert_eq!(&b[..4], b"1234"); + + assert_eq!(stream.state, StreamState::LocalClosing); + }); + } + + #[test] + fn test_write_with_half_close() { + let rt = rt(); + rt.block_on(async { + let (mut frame_sender, frame_receiver) = channel(2); + let (unbound_sender, mut unbound_receiver) = unbounded(); + let mut stream = StreamHandle::new( + 0, + unbound_sender, + frame_receiver, + StreamState::Init, + INITIAL_STREAM_WINDOW, + ); + + let flags = Flags::from(Flag::Fin); + let frame = Frame::new_window_update(flags, 0, 0); + frame_sender.send(frame).await.unwrap(); + let mut b = [0; 1024]; + + assert_eq!(stream.read(&mut b).await.unwrap(), 0); + assert_eq!(stream.state, StreamState::RemoteClosing); + + const TEXT: &[u8] = b"testtext"; + + let jh = tokio::spawn(tokio::time::timeout(std::time::Duration::from_secs(4), async move { + loop { + match unbound_receiver.try_next() { + Ok(Some(ref event)) if matches!(event, StreamEvent::Frame(frame) if frame.length() == TEXT.len() as u32) => break, + Err(_) => (), + _ => panic!("must be frame with written text"), + } + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + })); + + stream.write_all(TEXT).await.unwrap(); + + jh.await.unwrap().expect("not tiemout"); + + assert_eq!(stream.state, StreamState::RemoteClosing); + }); + } }