From 60c70a436124c25a5aa3a3b55464d01c22e7ea64 Mon Sep 17 00:00:00 2001
From: Mikhail Mazurskiy <126021+ash2k@users.noreply.github.com>
Date: Tue, 12 Nov 2024 11:02:57 +1100
Subject: [PATCH] mem: implement `ReadAll()` for more efficient `io.Reader`
 consumption (#7653)

---
 mem/buffer_slice.go      |  59 +++++++-
 mem/buffer_slice_test.go | 308 +++++++++++++++++++++++++++++++++++++++
 rpc_util.go              |   3 +-
 3 files changed, 366 insertions(+), 4 deletions(-)

diff --git a/mem/buffer_slice.go b/mem/buffer_slice.go
index 228e9c2f20f2..65002e2cc851 100644
--- a/mem/buffer_slice.go
+++ b/mem/buffer_slice.go
@@ -22,6 +22,11 @@ import (
 	"io"
 )
 
+const (
+	// 32 KiB is what io.Copy uses.
+	readAllBufSize = 32 * 1024
+)
+
 // BufferSlice offers a means to represent data that spans one or more Buffer
 // instances. A BufferSlice is meant to be immutable after creation, and methods
 // like Ref create and return copies of the slice. This is why all methods have
@@ -219,8 +224,58 @@ func (w *writer) Write(p []byte) (n int, err error) {
 
 // NewWriter wraps the given BufferSlice and BufferPool to implement the
 // io.Writer interface. Every call to Write copies the contents of the given
-// buffer into a new Buffer pulled from the given pool and the Buffer is added to
-// the given BufferSlice.
+// buffer into a new Buffer pulled from the given pool and the Buffer is
+// added to the given BufferSlice.
 func NewWriter(buffers *BufferSlice, pool BufferPool) io.Writer {
 	return &writer{buffers: buffers, pool: pool}
 }
+
+// ReadAll reads from r until an error or EOF and returns the data it read.
+// A successful call returns err == nil, not err == io.EOF. Because ReadAll
+// is defined to read from r until EOF, it does not treat an EOF from Read
+// as an error to be reported.
+//
+// Important: A failed call returns a non-nil error and may also return
+// partially read buffers. It is the responsibility of the caller to free the
+// returned BufferSlice, or its memory will not be reused.
+func ReadAll(r io.Reader, pool BufferPool) (BufferSlice, error) {
+	var result BufferSlice
+	if wt, ok := r.(io.WriterTo); ok {
+		// This is more efficient because wt knows the size of the chunks it
+		// wants to write and we can therefore allocate buffers of an optimal
+		// size to fit them. E.g. the data might be a single big chunk that
+		// we would otherwise chop into 32 KiB pieces.
+		w := NewWriter(&result, pool)
+		_, err := wt.WriteTo(w)
+		return result, err
+	}
+nextBuffer:
+	for {
+		buf := pool.Get(readAllBufSize)
+		// We asked for 32 KiB but may have been given a bigger buffer.
+		// Use all of it if that's the case.
+		*buf = (*buf)[:cap(*buf)]
+		usedCap := 0
+		for {
+			n, err := r.Read((*buf)[usedCap:])
+			usedCap += n
+			if err != nil {
+				if usedCap == 0 {
+					// Nothing was read into this buf; return it to the pool.
+					pool.Put(buf)
+				} else {
+					*buf = (*buf)[:usedCap]
+					result = append(result, NewBuffer(buf, pool))
+				}
+				if err == io.EOF {
+					err = nil
+				}
+				return result, err
+			}
+			if len(*buf) == usedCap {
+				result = append(result, NewBuffer(buf, pool))
+				continue nextBuffer
+			}
+		}
+	}
+}
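[Editor's usage sketch, not part of the patch: it relies only on the `ReadAll` contract documented above, in particular that the caller must free the returned BufferSlice even on the error path. The `consume` helper and its input are hypothetical.]

package main

import (
	"fmt"
	"io"
	"strings"

	"google.golang.org/grpc/mem"
)

// consume drains r into pooled buffers. Because ReadAll can return
// partially read buffers together with a non-nil error, Free must run on
// the error path as well; deferring it immediately covers both paths.
func consume(r io.Reader) error {
	data, err := mem.ReadAll(r, mem.DefaultBufferPool())
	defer data.Free()
	if err != nil {
		return err
	}
	fmt.Printf("read %d bytes in %d buffer(s)\n", data.Len(), len(data))
	return nil
}

func main() {
	_ = consume(strings.NewReader("hello, world"))
}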
diff --git a/mem/buffer_slice_test.go b/mem/buffer_slice_test.go
index bb4384434ee2..bb9303f0e9e1 100644
--- a/mem/buffer_slice_test.go
+++ b/mem/buffer_slice_test.go
@@ -20,6 +20,8 @@ package mem_test
 
 import (
 	"bytes"
+	"crypto/rand"
+	"errors"
 	"fmt"
 	"io"
 	"testing"
@@ -27,6 +29,12 @@ import (
 	"google.golang.org/grpc/mem"
 )
 
+const (
+	minReadSize = 1
+	// Should match the readAllBufSize constant in buffer_slice.go (another package).
+	readAllBufSize = 32 * 1024 // 32 KiB
+)
+
 func newBuffer(data []byte, pool mem.BufferPool) mem.Buffer {
 	return mem.NewBuffer(&data, pool)
 }
@@ -156,6 +164,252 @@ func (s) TestBufferSlice_Reader(t *testing.T) {
 	}
 }
 
+// TestBufferSlice_ReadAll_Reads exercises ReadAll by allowing it to read
+// various combinations of data, empty reads, errors, and EOF.
+func (s) TestBufferSlice_ReadAll_Reads(t *testing.T) {
+	testcases := []struct {
+		name     string
+		reads    []readStep
+		wantErr  string
+		wantBufs int
+	}{
+		{
+			name: "EOF",
+			reads: []readStep{
+				{
+					err: io.EOF,
+				},
+			},
+		},
+		{
+			name: "data,EOF",
+			reads: []readStep{
+				{
+					n: minReadSize,
+				},
+				{
+					err: io.EOF,
+				},
+			},
+			wantBufs: 1,
+		},
+		{
+			name: "data+EOF",
+			reads: []readStep{
+				{
+					n:   minReadSize,
+					err: io.EOF,
+				},
+			},
+			wantBufs: 1,
+		},
+		{
+			name: "0,data+EOF",
+			reads: []readStep{
+				{},
+				{
+					n:   minReadSize,
+					err: io.EOF,
+				},
+			},
+			wantBufs: 1,
+		},
+		{
+			name: "0,data,EOF",
+			reads: []readStep{
+				{},
+				{
+					n: minReadSize,
+				},
+				{
+					err: io.EOF,
+				},
+			},
+			wantBufs: 1,
+		},
+		{
+			name: "data,data+EOF",
+			reads: []readStep{
+				{
+					n: minReadSize,
+				},
+				{
+					n:   minReadSize,
+					err: io.EOF,
+				},
+			},
+			wantBufs: 1,
+		},
+		{
+			name: "error",
+			reads: []readStep{
+				{
+					err: errors.New("boom"),
+				},
+			},
+			wantErr: "boom",
+		},
+		{
+			name: "data+error",
+			reads: []readStep{
+				{
+					n:   minReadSize,
+					err: errors.New("boom"),
+				},
+			},
+			wantErr:  "boom",
+			wantBufs: 1,
+		},
+		{
+			name: "data,data+error",
+			reads: []readStep{
+				{
+					n: minReadSize,
+				},
+				{
+					n:   minReadSize,
+					err: errors.New("boom"),
+				},
+			},
+			wantErr:  "boom",
+			wantBufs: 1,
+		},
+		{
+			name: "data,data+EOF - whole buf",
+			reads: []readStep{
+				{
+					n: minReadSize,
+				},
+				{
+					n:   readAllBufSize - minReadSize,
+					err: io.EOF,
+				},
+			},
+			wantBufs: 1,
+		},
+		{
+			name: "data,data,EOF - whole buf",
+			reads: []readStep{
+				{
+					n: minReadSize,
+				},
+				{
+					n: readAllBufSize - minReadSize,
+				},
+				{
+					err: io.EOF,
+				},
+			},
+			wantBufs: 1,
+		},
+		{
+			name: "data,data,data,data,EOF - 3 bufs",
+			reads: []readStep{
+				{
+					n: readAllBufSize,
+				},
+				{
+					n: minReadSize,
+				},
+				{
+					n: readAllBufSize - minReadSize,
+				},
+				{
+					n: minReadSize,
+				},
+				{
+					err: io.EOF,
+				},
+			},
+			wantBufs: 3,
+		},
+	}
+
+	for _, tc := range testcases {
+		t.Run(tc.name, func(t *testing.T) {
+			pool := &testPool{
+				allocated: make(map[*[]byte]struct{}),
+			}
+			r := &stepReader{
+				reads: tc.reads,
+			}
+			data, err := mem.ReadAll(r, pool)
+			if tc.wantErr != "" {
+				if err == nil || err.Error() != tc.wantErr {
+					t.Fatalf("ReadAll() returned err %v, wanted %q", err, tc.wantErr)
+				}
+			} else if err != nil {
+				t.Fatal(err)
+			}
+			gotData := data.Materialize()
+			if !bytes.Equal(r.read, gotData) {
+				t.Fatalf("ReadAll() returned data %q, wanted %q", gotData, r.read)
+			}
+			if len(data) != tc.wantBufs {
+				t.Fatalf("ReadAll() returned %d bufs, wanted %d bufs", len(data), tc.wantBufs)
+			}
+			// All buffers except the last one should be full.
+			for i := 0; i < len(data)-1; i++ {
+				if data[i].Len() != readAllBufSize {
+					t.Fatalf("ReadAll() returned data length %d, wanted %d", data[i].Len(), readAllBufSize)
+				}
+			}
+			data.Free()
+			if len(pool.allocated) > 0 {
+				t.Fatalf("got %d allocated buffers, wanted none", len(pool.allocated))
+			}
+		})
+	}
+}
+
+func (s) TestBufferSlice_ReadAll_WriteTo(t *testing.T) {
+	testcases := []struct {
+		name string
+		size int
+	}{
+		{
+			name: "small",
+			size: minReadSize,
+		},
+		{
+			name: "exact size",
+			size: readAllBufSize,
+		},
+		{
+			name: "big",
+			size: readAllBufSize * 3,
+		},
+	}
+	for _, tc := range testcases {
+		t.Run(tc.name, func(t *testing.T) {
+			pool := &testPool{
+				allocated: make(map[*[]byte]struct{}),
+			}
+			buf := make([]byte, tc.size)
+			_, err := rand.Read(buf)
+			if err != nil {
+				t.Fatal(err)
+			}
+			r := bytes.NewBuffer(buf)
+			data, err := mem.ReadAll(r, pool)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			gotData := data.Materialize()
+			if !bytes.Equal(buf, gotData) {
+				t.Fatalf("ReadAll() = %q, wanted %q", gotData, buf)
+			}
+			data.Free()
+			if len(pool.allocated) > 0 {
+				t.Fatalf("wanted no allocated buffers, got %d", len(pool.allocated))
+			}
+		})
+	}
+}
+
 func ExampleNewWriter() {
 	var bs mem.BufferSlice
 	pool := mem.DefaultBufferPool()
@@ -176,3 +430,57 @@ func ExampleNewWriter() {
 	// Wrote 4 bytes, err: <nil>
 	// abcdabcdabcd
 }
+
+var (
+	_ io.Reader      = (*stepReader)(nil)
+	_ mem.BufferPool = (*testPool)(nil)
+)
+
+// readStep describes what a single stepReader.Read should do: how much data
+// to return and what error to return.
+type readStep struct {
+	n   int
+	err error
+}
+
+// stepReader is an io.Reader that returns the specified amount of data
+// and/or the specified error on each Read call, as scripted by reads.
+// All data it has produced is accumulated in the read field.
+type stepReader struct {
+	reads []readStep
+	read  []byte
+}
+
+func (s *stepReader) Read(buf []byte) (int, error) {
+	if len(s.reads) == 0 {
+		panic("unexpected Read() call")
+	}
+	read := s.reads[0]
+	s.reads = s.reads[1:]
+	_, err := rand.Read(buf[:read.n])
+	if err != nil {
+		panic(err)
+	}
+	s.read = append(s.read, buf[:read.n]...)
+	return read.n, read.err
+}
+
+// testPool is a BufferPool implementation that ensures that:
+//   - there is a matching Put call for every Get call.
+//   - there are no unexpected Put calls.
+type testPool struct {
+	allocated map[*[]byte]struct{}
+}
+
+func (t *testPool) Get(length int) *[]byte {
+	buf := make([]byte, length)
+	t.allocated[&buf] = struct{}{}
+	return &buf
+}
+
+func (t *testPool) Put(buf *[]byte) {
+	if _, ok := t.allocated[buf]; !ok {
+		panic("unexpected put")
+	}
+	delete(t.allocated, buf)
+}
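[Editor's note on the WriteTo test above, with a small sketch that is not part of the patch: *bytes.Buffer implements io.WriterTo, so TestBufferSlice_ReadAll_WriteTo exercises ReadAll's delegation fast path, while stepReader, which implements only io.Reader, exercises the chunked read loop.]

package main

import (
	"bytes"
	"io"
)

// Compile-time check that *bytes.Buffer satisfies io.WriterTo. Per the
// fast-path comment in ReadAll, such readers hand their data to the
// mem.NewWriter in chunks of their own choosing - for bytes.Buffer, a
// single Write of the entire remaining content - so nothing is chopped
// into 32 KiB pieces.
var _ io.WriterTo = (*bytes.Buffer)(nil)

func main() {}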
diff --git a/rpc_util.go b/rpc_util.go
index 033ffdc1c9bf..06c1f1b2855e 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -899,8 +899,7 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMes
 	//	}
 	//}
 
-	var out mem.BufferSlice
-	_, err = io.Copy(mem.NewWriter(&out, pool), io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
+	out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool)
 	if err != nil {
 		out.Free()
 		return nil, 0, err
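[Editor's standalone sketch of the call-site pattern above, not part of the patch; dcReader and maxReceiveMessageSize are grpc internals and are replaced here by a hypothetical reader and constant. io.LimitReader caps how much ReadAll can pull, which is how decompress keeps bounding memory at maxReceiveMessageSize+1 bytes.]

package main

import (
	"fmt"
	"io"
	"strings"

	"google.golang.org/grpc/mem"
)

func main() {
	// Hypothetical stand-ins for dcReader and maxReceiveMessageSize+1.
	src := strings.NewReader("decompressed payload that might be oversized")
	const limit = 16

	// Same shape as the decompress() change: bound the reader first, then
	// let ReadAll gather whatever remains into pooled buffers.
	out, err := mem.ReadAll(io.LimitReader(src, limit), mem.DefaultBufferPool())
	defer out.Free()
	if err != nil {
		fmt.Println("read failed:", err)
		return
	}
	fmt.Printf("read %d byte(s), capped at %d\n", out.Len(), limit)
}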