From b1eb9cc85df7e7e270810b8de2613951375c1d4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Eriksson?= Date: Mon, 14 Sep 2020 13:51:59 +0200 Subject: [PATCH] Add support for closing write ends of streams When sending payloads of unknown length over a Stream and expecting the server to read it to completion before emitting a response (such as forwarding a byte stream), one difficulty with yamux is communicating that the byte stream's end has been reached. One approach is to introduce a higher-level protocol, but when the byte stream is of unknown size this requires essentially reimplementing large parts of yamux's framing protocol. A simpler solution is to communicate that EOF has been reached. Yamux provides this capability through (*Stream).Close, but it closes both the read and the write ends, which then prevents the client from reading the response from the server. This change introduces a new method, (*Stream).CloseWrite, which only closes the write end of the stream. When encountered on the other end, it sets a flag that the remote's write end has been closed and begins returning EOF from any reads after the receive buffer has been exhausted. --- const.go | 8 +++++ session_test.go | 31 ++++++++++++++++++++ stream.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+) diff --git a/const.go b/const.go index 4f52938..84cf008 100644 --- a/const.go +++ b/const.go @@ -35,6 +35,9 @@ var ( // ErrStreamClosed is returned when using a closed stream ErrStreamClosed = fmt.Errorf("stream closed") + // ErrWriteClosed is returned when using a closed write end of a stream + ErrWriteClosed = fmt.Errorf("write end of stream closed") + // ErrUnexpectedFlag is set when we get an unexpected flag ErrUnexpectedFlag = fmt.Errorf("unexpected flag") @@ -93,6 +96,11 @@ const ( // RST is used to hard close a given stream. flagRST + + // flagCloseWrite is sent to notify the remote end + // that no more data will be written to the stream. + // May be sent with a data payload. + flagCloseWrite ) const ( diff --git a/session_test.go b/session_test.go index 4bbdfde..46a56ef 100644 --- a/session_test.go +++ b/session_test.go @@ -1351,3 +1351,34 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) { wg.Wait() } + +func TestCloseWrite(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + stream2, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream2.Close() + + if _, err := stream.Write([]byte("test")); err != nil { + t.Fatal(err) + } else if err := stream.CloseWrite(); err != nil { + t.Fatal(err) + } + + data, err := ioutil.ReadAll(stream2) + if err != nil { + t.Fatal(err) + } else if !bytes.Equal(data, []byte("test")) { + t.Fatalf("got data %q, want %q", data, "test") + } +} diff --git a/stream.go b/stream.go index e951c22..adb4f62 100644 --- a/stream.go +++ b/stream.go @@ -21,6 +21,14 @@ const ( streamReset ) +type streamFlags uint16 + +const ( + writeCloseFlag streamFlags = 1 << iota + writeCloseFlagSent + readCloseFlag +) + // Stream is used to represent a logical stream // within a session. type Stream struct { @@ -31,6 +39,7 @@ type Stream struct { session *Session state streamState + flags streamFlags stateLock sync.Mutex recvBuf *bytes.Buffer @@ -104,6 +113,15 @@ START: s.stateLock.Unlock() return 0, ErrConnectionReset } + if (s.flags & readCloseFlag) != 0 { + s.recvLock.Lock() + if s.recvBuf == nil || s.recvBuf.Len() == 0 { + s.recvLock.Unlock() + s.stateLock.Unlock() + return 0, io.EOF + } + s.recvLock.Unlock() + } s.stateLock.Unlock() // If there is no data available, block @@ -174,6 +192,10 @@ START: s.stateLock.Unlock() return 0, ErrConnectionReset } + if (s.flags & writeCloseFlag) != 0 { + s.stateLock.Unlock() + return 0, ErrWriteClosed + } s.stateLock.Unlock() // If there is no data available, block @@ -231,6 +253,10 @@ func (s *Stream) sendFlags() uint16 { flags |= flagACK s.state = streamEstablished } + if (s.flags & writeCloseFlag & ^writeCloseFlagSent) != 0 { + flags |= flagCloseWrite + s.flags |= writeCloseFlagSent + } return flags } @@ -321,6 +347,53 @@ SEND_CLOSE: return nil } +// CloseWrite is used to close this side's write end of the stream. +func (s *Stream) CloseWrite() error { + s.stateLock.Lock() + s.flags |= writeCloseFlag + switch s.state { + // Opened means we need to signal a close + case streamSYNSent: + fallthrough + case streamSYNReceived: + fallthrough + case streamEstablished: + goto SEND_CLOSE + + case streamLocalClose: + case streamRemoteClose: + goto SEND_CLOSE + case streamClosed: + case streamReset: + default: + panic("unhandled state") + } + s.stateLock.Unlock() + return nil +SEND_CLOSE: + s.stateLock.Unlock() + s.sendCloseWrite() + s.notifyWaiting() + return nil +} + +// sendCloseWrite is used to send a write close notice +func (s *Stream) sendCloseWrite() error { + s.controlHdrLock.Lock() + defer s.controlHdrLock.Unlock() + + flags := s.sendFlags() + if (flags & flagCloseWrite) == 0 { + // We have already sent it; no need to do so again + return nil + } + s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) + if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + return err + } + return nil +} + // forceClose is used for when the session is exiting func (s *Stream) forceClose() { s.stateLock.Lock() @@ -348,6 +421,10 @@ func (s *Stream) processFlags(flags uint16) error { } s.session.establishStream(s.id) } + if (flags & flagCloseWrite) == flagCloseWrite { + s.flags |= readCloseFlag + s.notifyWaiting() + } if flags&flagFIN == flagFIN { switch s.state { case streamSYNSent: