Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: poll_read: wrong behavior on half close #362

Merged
merged 2 commits into from
Aug 31, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 113 additions & 37 deletions yamux/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ impl StreamHandle {
if flags.contains(Flag::Ack) && self.state == StreamState::SynSent {
self.state = StreamState::SynReceived;
}
let mut close_stream = false;
if flags.contains(Flag::Fin) {
match self.state {
StreamState::Init
Expand All @@ -186,18 +185,12 @@ impl StreamHandle {
}
StreamState::LocalClosing => {
self.state = StreamState::Closed;
close_stream = true;
}
_ => return Err(Error::UnexpectedFlag),
}
}
if flags.contains(Flag::Rst) {
self.state = StreamState::Reset;
close_stream = true;
}

if close_stream {
self.close()?;
}
Ok(())
}
Expand Down Expand Up @@ -308,25 +301,24 @@ impl StreamHandle {
Ok(())
}

fn check_self_state(&mut self) -> Result<(), io::Error> {
// Returns Ok(true) only if eof is reached.
fn check_self_state(&mut self) -> io::Result<bool> {
// if read buf is empty and state is close, return close error
if self.read_buf.is_empty() {
match self.state {
StreamState::RemoteClosing => {
StreamState::RemoteClosing | StreamState::Closed => {
debug!("closed(EOF)");
let _ignore = self.send_close();
Err(io::ErrorKind::UnexpectedEof.into())
// an empty read indicates that EOF is reached.
Ok(true)
}
StreamState::Reset => {
debug!("connection reset");
let _ignore = self.send_close();
Err(io::ErrorKind::ConnectionReset.into())
}
StreamState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
_ => Ok(()),
_ => Ok(false),
}
} else {
Ok(())
Ok(false)
}
}
}
Expand All @@ -337,7 +329,9 @@ impl AsyncRead for StreamHandle {
cx: &mut Context,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.check_self_state()?;
if self.check_self_state()? {
return Poll::Ready(Ok(()));
}

if let Err(e) = self.recv_frames(cx) {
match e {
Expand All @@ -350,7 +344,9 @@ impl AsyncRead for StreamHandle {
}
}

self.check_self_state()?;
if self.check_self_state()? {
return Poll::Ready(Ok(()));
}

let n = ::std::cmp::min(buf.remaining(), self.read_buf.len());
trace!(
Expand All @@ -367,11 +363,8 @@ impl AsyncRead for StreamHandle {

buf.put_slice(&b);
match self.state {
StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => (),
StreamState::LocalClosing => {
if self.close().is_err() {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
}
StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => {
debug!("this branch should be unreachable")
}
_ => {
if self.send_window_update().is_err() {
Expand All @@ -391,9 +384,7 @@ impl AsyncWrite for StreamHandle {
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.state {
StreamState::RemoteClosing | StreamState::Reset => {
driftluo marked this conversation as resolved.
Show resolved Hide resolved
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
}
StreamState::Reset => return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
StreamState::LocalClosing | StreamState::Closed => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
Expand Down Expand Up @@ -446,18 +437,27 @@ impl AsyncWrite for StreamHandle {
impl Drop for StreamHandle {
fn drop(&mut self) {
if !self.unbound_event_sender.is_closed() && self.state != StreamState::Closed {
let event = StreamEvent::Closed(self.id);
// LocalClosing means that local have sent Fin to the remote and waiting for a response.
// if not, we should send Rst first
if self.state != StreamState::LocalClosing {
let mut flags = self.get_flags();
flags.add(Flag::Rst);
let frame = Frame::new_window_update(flags, self.id, 0);
let rst_event = StreamEvent::Frame(frame);

// Always successful unless the session is dropped
let _ignore = self.unbound_event_sender.unbounded_send(rst_event);
match self.state {
// LocalClosing means that local have sent Fin to the remote and waiting for a response.
StreamState::LocalClosing | StreamState::Reset => (),
// if not, we should send Rst first
StreamState::Established
| StreamState::Init
| StreamState::RemoteClosing
| StreamState::SynReceived
| StreamState::SynSent => {
let mut flags = self.get_flags();
flags.add(Flag::Rst);
let frame = Frame::new_window_update(flags, self.id, 0);
let rst_event = StreamEvent::Frame(frame);

// Always successful unless the session is dropped
let _ignore = self.unbound_event_sender.unbounded_send(rst_event);
}
StreamState::Closed => unreachable!(),
}

let event = StreamEvent::Closed(self.id);
let _ignore = self.unbound_event_sender.unbounded_send(event);
}
}
Expand Down Expand Up @@ -560,10 +560,13 @@ mod test {
// try poll stream handle, then it will recv RST frame and set self state to reset
assert_eq!(
stream.read(&mut b).await.unwrap_err().kind(),
ErrorKind::BrokenPipe
ErrorKind::ConnectionReset
);

assert_eq!(stream.state, StreamState::Reset);

drop(stream);

let event = unbound_receiver.next().await.unwrap();
match event {
StreamEvent::Closed(_) => (),
Expand Down Expand Up @@ -681,4 +684,77 @@ mod test {
}
});
}

#[test]
fn test_read_with_half_close() {
let rt = rt();
rt.block_on(async {
let (mut frame_sender, frame_receiver) = channel(2);
let (unbound_sender, _unbound_receiver) = unbounded();
let mut stream = StreamHandle::new(
0,
unbound_sender,
frame_receiver,
StreamState::Init,
INITIAL_STREAM_WINDOW,
);

stream.shutdown().await.unwrap();

assert_eq!(stream.state, StreamState::LocalClosing);

let flags = Flags::from(Flag::Syn);
let frame = Frame::new_data(flags, 0, BytesMut::from("1234"));
frame_sender.send(frame).await.unwrap();
let mut b = [0; 1024];

assert_eq!(stream.read(&mut b).await.unwrap(), 4);
assert_eq!(&b[..4], b"1234");

assert_eq!(stream.state, StreamState::LocalClosing);
});
}

#[test]
fn test_write_with_half_close() {
let rt = rt();
rt.block_on(async {
let (mut frame_sender, frame_receiver) = channel(2);
let (unbound_sender, mut unbound_receiver) = unbounded();
let mut stream = StreamHandle::new(
0,
unbound_sender,
frame_receiver,
StreamState::Init,
INITIAL_STREAM_WINDOW,
);

let flags = Flags::from(Flag::Fin);
let frame = Frame::new_window_update(flags, 0, 0);
frame_sender.send(frame).await.unwrap();
let mut b = [0; 1024];

assert_eq!(stream.read(&mut b).await.unwrap(), 0);
assert_eq!(stream.state, StreamState::RemoteClosing);

const TEXT: &[u8] = b"testtext";

let jh = tokio::spawn(tokio::time::timeout(std::time::Duration::from_secs(4), async move {
loop {
match unbound_receiver.try_next() {
Ok(Some(ref event)) if matches!(event, StreamEvent::Frame(frame) if frame.length() == TEXT.len() as u32) => break,
Err(_) => (),
_ => panic!("must be frame with written text"),
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
}));

stream.write_all(TEXT).await.unwrap();

jh.await.unwrap().expect("not tiemout");

assert_eq!(stream.state, StreamState::RemoteClosing);
});
}
}