From ac7cc4f799f2d4fa19d41d11f496113f252020b2 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Wed, 18 Dec 2024 00:50:01 +0530 Subject: [PATCH] Add a test for decompression exceeding max receive message size --- test/compressor_test.go | 61 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/compressor_test.go b/test/compressor_test.go index 0495a06f0968..eb7ab7f1837e 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -785,3 +785,64 @@ func (s) TestGzipBadChecksum(t *testing.T) { t.Errorf("ss.Client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum) } } + +// fakeCompressor returns a messages of a configured size, irrespective of the +// input. +type fakeCompressor struct { + decompressedMessageSize int +} + +func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) { + return nopWriteCloser{w}, nil +} + +func (f *fakeCompressor) Decompress(io.Reader) (io.Reader, error) { + return bytes.NewReader(make([]byte, f.decompressedMessageSize)), nil +} + +func (f *fakeCompressor) Name() string { + return "fake-compressor" +} + +type nopWriteCloser struct { + io.Writer +} + +func (nopWriteCloser) Close() error { + return nil +} + +// TestDecompressionExceedsMaxMessageSize uses a fake compressor that produces +// messages of size 100 bytes on decompression. A server is started with the +// max receive message size restricted to 99 bytes. The test verifies that the +// client receives a ResourceExhausted response from the server. +func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) { + ss := &stubserver.StubServer{ + UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + } + messageLen := 100 + encoding.RegisterCompressor(&fakeCompressor{decompressedMessageSize: messageLen}) + if err := ss.Start([]grpc.ServerOption{grpc.MaxRecvMsgSize(messageLen - 1)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + p, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(50)) + if err != nil { + t.Fatalf("Unexpected error from newPayload: %v", err) + } + req := &testpb.SimpleRequest{Payload: p} + _, err = ss.Client.UnaryCall(ctx, req, grpc.UseCompressor("fake-compressor")) + if err == nil { + t.Errorf("Client.UnaryCall(%+v) = nil, want %v", req, codes.ResourceExhausted) + } + + if got, want := status.Code(err), codes.ResourceExhausted; got != want { + t.Errorf("Client.UnaryCall(%+v) returned stats %v, want %v", req, got, want) + } +}