Skip to content

Commit

Permalink
Add a test for decompression exceeding max receive message size
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed Dec 17, 2024
1 parent e8055ea commit ac7cc4f
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions test/compressor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,64 @@ func (s) TestGzipBadChecksum(t *testing.T) {
t.Errorf("ss.Client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum)
}
}

// fakeCompressor returns a messages of a configured size, irrespective of the
// input.
type fakeCompressor struct {
decompressedMessageSize int
}

func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
return nopWriteCloser{w}, nil
}

func (f *fakeCompressor) Decompress(io.Reader) (io.Reader, error) {
return bytes.NewReader(make([]byte, f.decompressedMessageSize)), nil
}

func (f *fakeCompressor) Name() string {
return "fake-compressor"
}

type nopWriteCloser struct {
io.Writer
}

func (nopWriteCloser) Close() error {
return nil
}

// TestDecompressionExceedsMaxMessageSize uses a fake compressor that produces
// messages of size 100 bytes on decompression. A server is started with the
// max receive message size restricted to 99 bytes. The test verifies that the
// client receives a ResourceExhausted response from the server.
func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) {
ss := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
}
messageLen := 100
encoding.RegisterCompressor(&fakeCompressor{decompressedMessageSize: messageLen})
if err := ss.Start([]grpc.ServerOption{grpc.MaxRecvMsgSize(messageLen - 1)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

p, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(50))
if err != nil {
t.Fatalf("Unexpected error from newPayload: %v", err)
}
req := &testpb.SimpleRequest{Payload: p}
_, err = ss.Client.UnaryCall(ctx, req, grpc.UseCompressor("fake-compressor"))
if err == nil {
t.Errorf("Client.UnaryCall(%+v) = nil, want %v", req, codes.ResourceExhausted)
}

if got, want := status.Code(err), codes.ResourceExhausted; got != want {
t.Errorf("Client.UnaryCall(%+v) returned stats %v, want %v", req, got, want)
}
}

0 comments on commit ac7cc4f

Please sign in to comment.