diff --git a/stats/stats_test.go b/stats/stats_test.go index ec5ffa042f47..3c28642138d4 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -59,8 +59,10 @@ func init() { grpc.EnableTracing = false } -type connCtxKey struct{} -type rpcCtxKey struct{} +type ( + connCtxKey struct{} + rpcCtxKey struct{} +) var ( // For headers sent to server: @@ -81,6 +83,35 @@ var ( } // The id for which the service handler should return error. errorID int32 = 32202 + + // Ensure that Unary RPC server stats events are logged in the correct order. + expectedUnarySequence = []string{ + "ConnStats", + "InHeader", + "Begin", + "InPayload", + "OutHeader", + "OutPayload", + "OutTrailer", + "End", + } + + // Ensure that the sequence of server-side stats events for a Unary RPC + // matches the expected flow. + expectedClientStreamSequence = []string{ + "ConnStats", + "InHeader", + "Begin", + "OutHeader", + "InPayload", + "InPayload", + "InPayload", + "InPayload", + "InPayload", + "OutPayload", + "OutTrailer", + "End", + } ) func idToPayload(id int32) *testpb.Payload { @@ -119,12 +150,25 @@ type testServer struct { testgrpc.UnimplementedTestServiceServer } -func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { +func (s *testServer) UnaryCall( + ctx context.Context, + in *testpb.SimpleRequest, +) (*testpb.SimpleResponse, error) { if err := grpc.SendHeader(ctx, testHeaderMetadata); err != nil { - return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want ", testHeaderMetadata, err) + return nil, status.Errorf( + status.Code(err), + "grpc.SendHeader(_, %v) = %v, want ", + testHeaderMetadata, + err, + ) } if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { - return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want ", testTrailerMetadata, err) + return nil, status.Errorf( + status.Code(err), + "grpc.SetTrailer(_, %v) = %v, want ", + testTrailerMetadata, + err, + ) } if id := payloadToID(in.Payload); id == errorID { @@ -136,7 +180,14 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error { if err := stream.SendHeader(testHeaderMetadata); err != nil { - return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil) + return status.Errorf( + status.Code(err), + "%v.SendHeader(%v) = %v, want %v", + stream, + testHeaderMetadata, + err, + nil, + ) } stream.SetTrailer(testTrailerMetadata) for { @@ -159,9 +210,18 @@ func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallSe } } -func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error { +func (s *testServer) StreamingInputCall( + stream testgrpc.TestService_StreamingInputCallServer, +) error { if err := stream.SendHeader(testHeaderMetadata); err != nil { - return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil) + return status.Errorf( + status.Code(err), + "%v.SendHeader(%v) = %v, want %v", + stream, + testHeaderMetadata, + err, + nil, + ) } stream.SetTrailer(testTrailerMetadata) for { @@ -180,9 +240,19 @@ func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInp } } -func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error { +func (s *testServer) StreamingOutputCall( 
+ in *testpb.StreamingOutputCallRequest, + stream testgrpc.TestService_StreamingOutputCallServer, +) error { if err := stream.SendHeader(testHeaderMetadata); err != nil { - return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil) + return status.Errorf( + status.Code(err), + "%v.SendHeader(%v) = %v, want %v", + stream, + testHeaderMetadata, + err, + nil, + ) } stream.SetTrailer(testTrailerMetadata) @@ -326,7 +396,11 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - resp, err = tc.UnaryCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast)) + resp, err = tc.UnaryCall( + metadata.NewOutgoingContext(tCtx, testMetadata), + req, + grpc.WaitForReady(!c.failfast), + ) return req, resp, err } @@ -339,7 +413,10 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []prot tc := testgrpc.NewTestServiceClient(te.clientConn()) tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast)) + stream, err := tc.FullDuplexCall( + metadata.NewOutgoingContext(tCtx, testMetadata), + grpc.WaitForReady(!c.failfast), + ) if err != nil { return reqs, resps, err } @@ -371,7 +448,9 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []prot return reqs, resps, nil } -func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.StreamingInputCallResponse, error) { +func (te *test) doClientStreamCall( + c *rpcConfig, +) ([]proto.Message, *testpb.StreamingInputCallResponse, error) { var ( reqs []proto.Message resp *testpb.StreamingInputCallResponse @@ -380,7 +459,10 @@ func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.Strea tc := testgrpc.NewTestServiceClient(te.clientConn()) tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := tc.StreamingInputCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast)) + stream, err := tc.StreamingInputCall( + metadata.NewOutgoingContext(tCtx, testMetadata), + grpc.WaitForReady(!c.failfast), + ) if err != nil { return reqs, resp, err } @@ -401,7 +483,9 @@ func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.Strea return reqs, resp, err } -func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallRequest, []proto.Message, error) { +func (te *test) doServerStreamCall( + c *rpcConfig, +) (*testpb.StreamingOutputCallRequest, []proto.Message, error) { var ( req *testpb.StreamingOutputCallRequest resps []proto.Message @@ -417,7 +501,11 @@ func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallReq req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)} tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := tc.StreamingOutputCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast)) + stream, err := tc.StreamingOutputCall( + metadata.NewOutgoingContext(tCtx, testMetadata), + req, + grpc.WaitForReady(!c.failfast), + ) if err != nil { return req, resps, err } @@ -512,7 +600,12 @@ func checkInHeader(t *testing.T, d *gotData, e *expectedData) { // expected headers keys have the expected header values. 
for key := range testHeaderMetadata { if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) { - t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key)) + t.Fatalf( + "st.Header[%s] = %v, want %v", + key, + st.Header.Get(key), + testHeaderMetadata.Get(key), + ) } } } else { @@ -636,7 +729,12 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { // expected headers keys have the expected header values. for key := range testMetadata { if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) { - t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key)) + t.Fatalf( + "st.Header[%s] = %v, want %v", + key, + st.Header.Get(key), + testMetadata.Get(key), + ) } } @@ -786,8 +884,14 @@ func checkConnEnd(t *testing.T, d *gotData) { st.IsClient() // TODO remove this. } +type event struct { + eventType string + timestamp time.Time +} + type statshandler struct { mu sync.Mutex + events []event gotRPC []*gotData gotConn []*gotData } @@ -800,13 +904,41 @@ func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) conte return context.WithValue(ctx, rpcCtxKey{}, info) } +// recordEvent records an event in the statshandler along with a timestamp. +func (h *statshandler) recordEvent(eventType string) { + h.mu.Lock() + defer h.mu.Unlock() + h.events = append(h.events, event{eventType: eventType, timestamp: time.Now()}) +} + func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) { + h.recordEvent("ConnStats") + h.mu.Lock() defer h.mu.Unlock() h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s}) } func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) { + switch s.(type) { + case *stats.Begin: + h.recordEvent("Begin") + case *stats.InHeader: + h.recordEvent("InHeader") + case *stats.InPayload: + h.recordEvent("InPayload") + case *stats.OutHeader: + h.recordEvent("OutHeader") + case *stats.OutPayload: + h.recordEvent("OutPayload") + case *stats.InTrailer: + h.recordEvent("InTrailer") + case *stats.OutTrailer: + h.recordEvent("OutTrailer") + case *stats.End: + h.recordEvent("End") + } + h.mu.Lock() defer h.mu.Unlock() h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s}) @@ -825,7 +957,12 @@ func checkConnStats(t *testing.T, got []*gotData) { checkConnEnd(t, got[len(got)-1]) } -func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { +func checkServerStats( + t *testing.T, + got []*gotData, + expect *expectedData, + checkFuncs []func(t *testing.T, d *gotData, e *expectedData), +) { if len(got) != len(checkFuncs) { for i, g := range got { t.Errorf(" - %v, %T", i, g.s) @@ -838,7 +975,12 @@ func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkF } } -func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { +func testServerStats( + t *testing.T, + tc *testConfig, + cc *rpcConfig, + checkFuncs []func(t *testing.T, d *gotData, e *expectedData), +) { h := &statshandler{} te := newTest(t, tc, nil, []stats.Handler{h}) te.startServer(&testServer{}) @@ -927,26 +1069,36 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f } func (s) TestServerStatsUnaryRPC(t *testing.T) { - testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ - checkInHeader, - checkBegin, - 
checkInPayload, - checkOutHeader, - checkOutPayload, - checkOutTrailer, - checkEnd, - }) + testServerStats( + t, + &testConfig{compress: ""}, + &rpcConfig{success: true, callType: unaryRPC}, + []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkInPayload, + checkOutHeader, + checkOutPayload, + checkOutTrailer, + checkEnd, + }, + ) } func (s) TestServerStatsUnaryRPCError(t *testing.T) { - testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ - checkInHeader, - checkBegin, - checkInPayload, - checkOutHeader, - checkOutTrailer, - checkEnd, - }) + testServerStats( + t, + &testConfig{compress: ""}, + &rpcConfig{success: false, callType: unaryRPC}, + []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkInPayload, + checkOutHeader, + checkOutTrailer, + checkEnd, + }, + ) } func (s) TestServerStatsClientStreamRPC(t *testing.T) { @@ -967,19 +1119,29 @@ func (s) TestServerStatsClientStreamRPC(t *testing.T) { checkOutTrailer, checkEnd, ) - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs) + testServerStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: true, callType: clientStreamRPC}, + checkFuncs, + ) } func (s) TestServerStatsClientStreamRPCError(t *testing.T) { count := 1 - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ - checkInHeader, - checkBegin, - checkOutHeader, - checkInPayload, - checkOutTrailer, - checkEnd, - }) + testServerStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: false, callType: clientStreamRPC}, + []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkOutHeader, + checkInPayload, + checkOutTrailer, + checkEnd, + }, + ) } func (s) TestServerStatsServerStreamRPC(t *testing.T) { @@ -1000,19 +1162,29 @@ func (s) TestServerStatsServerStreamRPC(t *testing.T) { checkOutTrailer, checkEnd, ) - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs) + testServerStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: true, callType: serverStreamRPC}, + checkFuncs, + ) } func (s) TestServerStatsServerStreamRPCError(t *testing.T) { count := 5 - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ - checkInHeader, - checkBegin, - checkInPayload, - checkOutHeader, - checkOutTrailer, - checkEnd, - }) + testServerStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: false, callType: serverStreamRPC}, + []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkInPayload, + checkOutHeader, + checkOutTrailer, + checkEnd, + }, + ) } func (s) TestServerStatsFullDuplexRPC(t *testing.T) { @@ -1033,19 +1205,29 @@ func (s) TestServerStatsFullDuplexRPC(t *testing.T) { checkOutTrailer, checkEnd, ) - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs) + testServerStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, + checkFuncs, + ) } func 
(s) TestServerStatsFullDuplexRPCError(t *testing.T) { count := 5 - testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ - checkInHeader, - checkBegin, - checkOutHeader, - checkInPayload, - checkOutTrailer, - checkEnd, - }) + testServerStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, + []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkOutHeader, + checkInPayload, + checkOutTrailer, + checkEnd, + }, + ) } type checkFuncWithCount struct { @@ -1053,7 +1235,12 @@ type checkFuncWithCount struct { c int // expected count } -func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) { +func checkClientStats( + t *testing.T, + got []*gotData, + expect *expectedData, + checkFuncs map[int]*checkFuncWithCount, +) { var expectLen int for _, v := range checkFuncs { expectLen += v.c @@ -1138,7 +1325,12 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF } } -func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) { +func testClientStats( + t *testing.T, + tc *testConfig, + cc *rpcConfig, + checkFuncs map[int]*checkFuncWithCount, +) { h := &statshandler{} te := newTest(t, tc, []stats.Handler{h}, nil) te.startServer(&testServer{}) @@ -1231,101 +1423,141 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map } func (s) TestClientStatsUnaryRPC(t *testing.T) { - testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - outPayload: {checkOutPayload, 1}, - inHeader: {checkInHeader, 1}, - inPayload: {checkInPayload, 1}, - inTrailer: {checkInTrailer, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: ""}, + &rpcConfig{success: true, failfast: false, callType: unaryRPC}, + map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inPayload: {checkInPayload, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestClientStatsUnaryRPCError(t *testing.T) { - testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - outPayload: {checkOutPayload, 1}, - inHeader: {checkInHeader, 1}, - inTrailer: {checkInTrailer, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: ""}, + &rpcConfig{success: false, failfast: false, callType: unaryRPC}, + map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestClientStatsClientStreamRPC(t *testing.T) { count := 5 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - inHeader: {checkInHeader, 1}, - outPayload: {checkOutPayload, count}, - inTrailer: {checkInTrailer, 1}, - inPayload: 
{checkInPayload, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, + map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + inHeader: {checkInHeader, 1}, + outPayload: {checkOutPayload, count}, + inTrailer: {checkInTrailer, 1}, + inPayload: {checkInPayload, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestClientStatsClientStreamRPCError(t *testing.T) { count := 1 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - inHeader: {checkInHeader, 1}, - outPayload: {checkOutPayload, 1}, - inTrailer: {checkInTrailer, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, + map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + inHeader: {checkInHeader, 1}, + outPayload: {checkOutPayload, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestClientStatsServerStreamRPC(t *testing.T) { count := 5 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - outPayload: {checkOutPayload, 1}, - inHeader: {checkInHeader, 1}, - inPayload: {checkInPayload, count}, - inTrailer: {checkInTrailer, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, + map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inPayload: {checkInPayload, count}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestClientStatsServerStreamRPCError(t *testing.T) { count := 5 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - outPayload: {checkOutPayload, 1}, - inHeader: {checkInHeader, 1}, - inTrailer: {checkInTrailer, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, + map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestClientStatsFullDuplexRPC(t *testing.T) { count := 5 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - outPayload: {checkOutPayload, count}, - inHeader: {checkInHeader, 1}, - inPayload: {checkInPayload, count}, - inTrailer: {checkInTrailer, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, + 
map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, count}, + inHeader: {checkInHeader, 1}, + inPayload: {checkInPayload, count}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestClientStatsFullDuplexRPCError(t *testing.T) { count := 5 - testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{ - begin: {checkBegin, 1}, - outHeader: {checkOutHeader, 1}, - outPayload: {checkOutPayload, 1}, - inHeader: {checkInHeader, 1}, - inTrailer: {checkInTrailer, 1}, - end: {checkEnd, 1}, - }) + testClientStats( + t, + &testConfig{compress: "gzip"}, + &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, + map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }, + ) } func (s) TestTags(t *testing.T) { @@ -1496,7 +1728,9 @@ func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) { t.Errorf("DoesNotExistCall should not be a registered method according to server") } if isRegisteredMethod(server, "/unknownService/UnaryCall") { - t.Errorf("/unknownService/UnaryCall should not be a registered method according to server") + t.Errorf( + "/unknownService/UnaryCall should not be a registered method according to server", + ) } wg.Done() return ctx @@ -1519,3 +1753,66 @@ func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) { } wg.Wait() } + +// TestServerStatsUnaryRPCEventSequence tests that the sequence of server-side stats +// events for a Unary RPC matches the expected flow. +func (s) TestServerStatsUnaryRPCEventSequence(t *testing.T) { + h := &statshandler{} + te := newTest(t, &testConfig{compress: ""}, nil, []stats.Handler{h}) + te.startServer(&testServer{}) + defer te.tearDown() + + _, _, err := te.doUnaryCall(&rpcConfig{success: true, callType: unaryRPC}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Allow time for events to propagate + time.Sleep(50 * time.Millisecond) + + // Verify sequence + h.mu.Lock() + defer h.mu.Unlock() + verifyEventSequence(t, h.events, expectedUnarySequence) +} + +// TestServerStatsClientStreamEventSequence tests that the sequence of server-side +// stats events for a Client Stream RPC matches the expected flow. +func (s) TestServerStatsClientStreamEventSequence(t *testing.T) { + h := &statshandler{} + te := newTest(t, &testConfig{compress: "gzip"}, nil, []stats.Handler{h}) + te.startServer(&testServer{}) + defer te.tearDown() + + _, _, err := te.doClientStreamCall( + &rpcConfig{count: 5, success: true, callType: clientStreamRPC}, + ) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + h.mu.Lock() + defer h.mu.Unlock() + verifyEventSequence(t, h.events, expectedClientStreamSequence) +} + +// verifyEventSequence verifies that a sequence of recorded events matches +// the expected sequence. +func verifyEventSequence(t *testing.T, got []event, expected []string) { + if len(got) != len(expected) { + t.Fatalf("Event count mismatch. Got: %d, Expected: %d", len(got), len(expected)) + } + + for i, e := range got { + if e.eventType != expected[i] { + t.Errorf( + "Unexpected event at position %d. 
Got: %s, Expected: %s", + i, + e.eventType, + expected[i], + ) + } + } +} diff --git a/xds/internal/xdsclient/tests/lrs_stream_backoff_test.go b/xds/internal/xdsclient/tests/lrs_stream_backoff_test.go new file mode 100644 index 000000000000..6a728bdcb06b --- /dev/null +++ b/xds/internal/xdsclient/tests/lrs_stream_backoff_test.go @@ -0,0 +1,490 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient_test + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/testutils/xds/e2e" + "google.golang.org/grpc/xds/internal/xdsclient/xdsresource" + "google.golang.org/grpc/xds/internal/xdsclient/xdsresource/version" + "google.golang.org/protobuf/testing/protocmp" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3discoverypb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" +) + +// Tests the case where the management server returns an error in the ADS +// streaming RPC. Verifies that the LRS stream is restarted after a backoff +// period, and that the previously requested resources are re-requested on the +// new stream. +func (s) TestLRS_BackoffAfterStreamFailure(t *testing.T) { + // Channels for test state. + streamCloseCh := make(chan struct{}, 1) + resourceRequestCh := make(chan []string, 1) + backoffCh := make(chan struct{}, 1) + // Context with timeout. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Simulate LRS stream error. + streamErr := errors.New("LRS stream error") + mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{ + SupportLoadReportingService: true, + OnStreamRequest: func(_ int64, req *v3discoverypb.DiscoveryRequest) error { + t.Logf("Simulated server: Received stream request: %+v\n", req) + if req.GetTypeUrl() == version.V3ListenerURL { + select { + case resourceRequestCh <- req.GetResourceNames(): + case <-ctx.Done(): + } + } + return streamErr + }, + OnStreamClosed: func(int64, *v3corepb.Node) { + t.Log("Simulated server: Stream closed") + select { + case streamCloseCh <- struct{}{}: + case <-ctx.Done(): + } + }, + }) + // Backoff behavior. + streamBackoff := func(v int) time.Duration { + t.Log("Backoff triggered") + select { + case backoffCh <- struct{}{}: + case <-ctx.Done(): + } + return 500 * time.Millisecond + } + // Create xDS client and bootstrap configuration. 
+ nodeID := uuid.New().String() + bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) + testutils.CreateBootstrapFileForTesting(t, bc) + client := createXDSClientWithBackoff(t, bc, streamBackoff) + + const listenerName = "listener" + lw := newListenerWatcher() + ldsCancel := xdsresource.WatchListener(client, listenerName, lw) + defer ldsCancel() + + // Verify resource request. + if err := waitForResourceNames(ctx, t, resourceRequestCh, []string{listenerName}); err != nil { + t.Fatal(err) + } + + // Verify that the received stream error is reported to the watcher. + u, err := lw.updateCh.Receive(ctx) + if err != nil { + t.Fatal("Timeout when waiting for an error callback on the listener watcher") + } + gotErr := u.(listenerUpdateErrTuple).err + if !strings.Contains(gotErr.Error(), streamErr.Error()) { + t.Fatalf("Received stream error: %v, wantErr: %v", gotErr, streamErr) + } + + // Verify stream closure. + select { + case <-streamCloseCh: + t.Log("Stream closure observed after error") + case <-ctx.Done(): + t.Fatal("Timeout waiting for LRS stream closure") + } + // Verify backoff signal. + select { + case <-backoffCh: + t.Log("Backoff observed before stream restart") + case <-ctx.Done(): + t.Fatal("Timeout waiting for backoff signal") + } + // Verify re-request. + if err := waitForResourceNames(ctx, t, resourceRequestCh, []string{listenerName}); err != nil { + t.Fatal(err) + } +} + +// Tests the case where a stream breaks because the server goes down. Verifies +// that when the server comes back up, the same resources are re-requested, +// this time with the previously acked version and an empty nonce. +func (s) TestLRS_BackoffAfterBrokenStream(t *testing.T) { + // Channels for verifying different events in the test. + streamCloseCh := make(chan struct{}, 1) // LRS stream is closed. + resourceRequestCh := make(chan []string, 1) // Resource names in the discovery request. + backoffCh := make(chan struct{}, 1) // Backoff after stream failure. + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Simulate LRS stream error. + // streamErr := errors.New("LRS stream error") + mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{ + SupportLoadReportingService: true, + OnStreamRequest: func(_ int64, req *v3discoverypb.DiscoveryRequest) error { + if req.GetTypeUrl() == version.V3ListenerURL { + t.Logf("Received LRS request for resources: %v", req.GetResourceNames()) + select { + case resourceRequestCh <- req.GetResourceNames(): + case <-ctx.Done(): + } + } + return errors.New("unsupported TypeURL") + }, + OnStreamClosed: func(int64, *v3corepb.Node) { + t.Log("Simulated server: Stream closed") + select { + case streamCloseCh <- struct{}{}: + case <-ctx.Done(): + } + }, + }) + + // Override the backoff implementation. + streamBackoff := func(v int) time.Duration { + t.Log("Backoff triggered") + select { + case backoffCh <- struct{}{}: + case <-ctx.Done(): + } + return 500 * time.Millisecond + } + + // Create an xDS client with bootstrap pointing to the above server. + nodeID := uuid.New().String() + bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) + testutils.CreateBootstrapFileForTesting(t, bc) + client := createXDSClientWithBackoff(t, bc, streamBackoff) + + // Register a watch for load reporting resource. + const resourceName = "load-report" + lw := newListenerWatcher() // Replace this with the correct LRS watcher if available. 
+ lrsCancel := xdsresource.WatchListener(client, resourceName, lw) + defer lrsCancel() + + // Verify the initial resource request. + if err := waitForResourceNames(ctx, t, resourceRequestCh, []string{resourceName}); err != nil { + t.Fatal(err) + } + + // Verify stream closure after an error. + select { + case <-streamCloseCh: + t.Log("Stream closure observed after error") + case <-ctx.Done(): + t.Fatal("Timeout waiting for LRS stream closure") + } + + // Verify backoff signal before restarting the stream. + select { + case <-backoffCh: + t.Log("Backoff observed before stream restart") + case <-ctx.Done(): + t.Fatal("Timeout waiting for backoff signal") + } + + // Verify the resource request is re-sent after stream recovery. + if err := waitForResourceNames(ctx, t, resourceRequestCh, []string{resourceName}); err != nil { + t.Fatal(err) + } +} + +// Tests the case where a stream breaks because the server goes down. Verifies +// that when the server comes back up, the same resources are re-requested, this +// time with the previously acked version and an empty nonce. +func (s) TestLRS_RetriesAfterBrokenStream(t *testing.T) { + // Channels used for verifying different events in the test. + streamRequestCh := make(chan *v3discoverypb.DiscoveryRequest, 1) // Discovery request is received. + streamResponseCh := make(chan *v3discoverypb.DiscoveryResponse, 1) // Discovery response is received. + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Create an xDS management server listening on a local port. + l, err := testutils.LocalTCPListener() + if err != nil { + t.Fatalf("Failed to create a local listener for the xDS management server: %v", err) + } + lis := testutils.NewRestartableListener(l) + mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{ + Listener: lis, + SupportLoadReportingService: true, + // Push the received request on to a channel for the test goroutine to + // verify that it matches expectations. + OnStreamRequest: func(_ int64, req *v3discoverypb.DiscoveryRequest) error { + select { + case streamRequestCh <- req: + case <-ctx.Done(): + } + return nil + }, + // Push the response that the management server is about to send on to a + // channel. The test goroutine to uses this to extract the version and + // nonce, expected on subsequent requests. + OnStreamResponse: func(_ context.Context, _ int64, _ *v3discoverypb.DiscoveryRequest, resp *v3discoverypb.DiscoveryResponse) { + select { + case streamResponseCh <- resp: + case <-ctx.Done(): + } + }, + }) + + // Create a listener resource on the management server. + const listenerName = "load-report" + const routeConfigName = "route-config" + nodeID := uuid.New().String() + resources := e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{e2e.DefaultClientListener(listenerName, routeConfigName)}, + SkipValidation: true, + } + if err := mgmtServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Override the backoff implementation to always return 0, to reduce test + // run time. Instead control when the backoff returns by blocking on a + // channel, that the test closes. + backoffCh := make(chan struct{}) + streamBackoff := func(v int) time.Duration { + select { + case backoffCh <- struct{}{}: + case <-ctx.Done(): + } + return 0 + } + + // Create an xDS client with bootstrap pointing to the above server. 
+ bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) + testutils.CreateBootstrapFileForTesting(t, bc) + client := createXDSClientWithBackoff(t, bc, streamBackoff) + + // Register a watch for a listener resource. + lw := newListenerWatcher() + ldsCancel := xdsresource.WatchListener(client, listenerName, lw) + defer ldsCancel() + + // Verify that the initial discovery request matches expectation. + var gotReq *v3discoverypb.DiscoveryRequest + select { + case gotReq = <-streamRequestCh: + case <-ctx.Done(): + t.Fatalf("Timeout waiting for discovery request on the stream") + } + wantReq := &v3discoverypb.DiscoveryRequest{ + VersionInfo: "", + Node: &v3corepb.Node{ + Id: nodeID, + UserAgentName: "gRPC Go", + UserAgentVersionType: &v3corepb.Node_UserAgentVersion{UserAgentVersion: grpc.Version}, + ClientFeatures: []string{"envoy.lb.does_not_support_overprovisioning", "xds.config.resource-in-sotw"}, + }, + ResourceNames: []string{listenerName}, + TypeUrl: "type.googleapis.com/envoy.config.listener.v3.Listener", + ResponseNonce: "", + } + if diff := cmp.Diff(gotReq, wantReq, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected diff in received discovery request, diff (-got, +want):\n%s", diff) + } + + // Capture the version and nonce from the response. + var gotResp *v3discoverypb.DiscoveryResponse + select { + case gotResp = <-streamResponseCh: + case <-ctx.Done(): + t.Fatalf("Timeout waiting for discovery response on the stream") + } + version := gotResp.GetVersionInfo() + nonce := gotResp.GetNonce() + + // Verify that the ACK contains the appropriate version and nonce. + wantReq.VersionInfo = version + wantReq.ResponseNonce = nonce + select { + case gotReq = <-streamRequestCh: + case <-ctx.Done(): + t.Fatalf("Timeout waiting for the discovery request ACK on the stream") + } + if diff := cmp.Diff(gotReq, wantReq, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected diff in received discovery request, diff (-got, +want):\n%s", diff) + } + + // Verify the update received by the watcher. + wantUpdate := listenerUpdateErrTuple{ + update: xdsresource.ListenerUpdate{ + RouteConfigName: routeConfigName, + HTTPFilters: []xdsresource.HTTPFilter{{Name: "router"}}, + }, + } + if err := verifyListenerUpdate(ctx, lw.updateCh, wantUpdate); err != nil { + t.Fatal(err) + } + + // Bring down the management server to simulate a broken stream. + lis.Stop() + + // Verify that the error callback on the watcher is not invoked. + verifyNoListenerUpdate(ctx, lw.updateCh) + + // Wait for backoff to kick in, and unblock the first backoff attempt. + select { + case <-backoffCh: + case <-ctx.Done(): + t.Fatal("Timeout waiting for stream backoff") + } + + // Bring up the management server. The test does not have prcecise control + // over when new streams to the management server will start succeeding. The + // ADS stream implementation will backoff as many times as required before + // it can successfully create a new stream. Therefore, we need to receive on + // the backoffCh as many times as required, and unblock the backoff + // implementation. + lis.Restart() + go func() { + for { + select { + case <-backoffCh: + case <-ctx.Done(): + return + } + } + }() + + // Verify that the transport creates a new stream and sends out a new + // request which contains the previously acked version, but an empty nonce. 
+	wantReq.ResponseNonce = ""
+	select {
+	case gotReq = <-streamRequestCh:
+	case <-ctx.Done():
+		t.Fatalf("Timeout waiting for the discovery request ACK on the stream")
+	}
+	if diff := cmp.Diff(gotReq, wantReq, protocmp.Transform()); diff != "" {
+		t.Fatalf("Unexpected diff in received discovery request, diff (-got, +want):\n%s", diff)
+	}
+}
+
+// Tests the case where a resource is requested before a valid ADS stream
+// exists. Verifies that a discovery request is sent out for the previously
+// requested resource once a valid stream is created.
+func (s) TestLRS_ResourceRequestedBeforeStreamCreation(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Channels for verifying different events in the test.
+	streamRequestCh := make(chan *v3discoverypb.DiscoveryRequest, 1)
+
+	// Create an xDS management server listening on a local port.
+	l, err := testutils.LocalTCPListener()
+	if err != nil {
+		t.Fatalf("Failed to create a local listener: %v", err)
+	}
+	defer l.Close()
+
+	lis := testutils.NewRestartableListener(l)
+	defer lis.Stop()
+
+	streamErr := errors.New("LRS stream error")
+
+	mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{
+		Listener:                    lis,
+		SupportLoadReportingService: true,
+		OnStreamRequest: func(id int64, req *v3discoverypb.DiscoveryRequest) error {
+			// Capture only Listener requests.
+			if req.GetTypeUrl() == version.V3ListenerURL {
+				select {
+				case streamRequestCh <- req:
+				default:
+				}
+			}
+			return streamErr
+		},
+	})
+	// defer mgmtServer.Stop()
+
+	// Bring down the management server before creating the transport.
+	lis.Stop()
+
+	// Override backoff to minimize test time.
+	backoffCh := make(chan struct{}, 1)
+	unblockBackoffCh := make(chan struct{})
+	streamBackoff := func(v int) time.Duration {
+		select {
+		case backoffCh <- struct{}{}:
+		default:
+		}
+		<-unblockBackoffCh
+		return 0
+	}
+
+	// Create an xDS client with bootstrap pointing to the above server.
+	nodeID := uuid.New().String()
+	bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address)
+	testutils.CreateBootstrapFileForTesting(t, bc)
+	client := createXDSClientWithBackoff(t, bc, streamBackoff)
+	if client == nil {
+		t.Fatalf("Failed to create xDS client")
+	}
+
+	// Register a listener watch for the "load-report" resource.
+	const listenerName = "load-report"
+	lw := newListenerWatcher()
+	ldsCancel := xdsresource.WatchListener(client, listenerName, lw)
+	defer ldsCancel()
+
+	// Wait for backoff to kick in.
+	select {
+	case <-backoffCh:
+	case <-ctx.Done():
+		t.Fatal("Timeout waiting for stream backoff")
+	}
+
+	// Bring up the connection to the management server and unblock the backoff.
+	lis.Restart()
+	close(unblockBackoffCh)
+
+	// Verify that the initial discovery request matches expectations.
+ var gotReq *v3discoverypb.DiscoveryRequest + select { + case gotReq = <-streamRequestCh: + case <-ctx.Done(): + t.Fatalf("Timeout waiting for discovery request on the stream") + } + wantReq := &v3discoverypb.DiscoveryRequest{ + VersionInfo: "", + Node: &v3corepb.Node{ + Id: nodeID, + UserAgentName: "gRPC Go", + UserAgentVersionType: &v3corepb.Node_UserAgentVersion{UserAgentVersion: grpc.Version}, + ClientFeatures: []string{"envoy.lb.does_not_support_overprovisioning", "xds.config.resource-in-sotw"}, + }, + ResourceNames: []string{listenerName}, + TypeUrl: "type.googleapis.com/envoy.config.listener.v3.Listener", + ResponseNonce: "", + } + if diff := cmp.Diff(gotReq, wantReq, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected diff in received discovery request, diff (-got, +want):\n%s", diff) + } +}
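Note: the new event-sequence tests in stats_test.go wait for server-side stats events with a fixed time.Sleep(50 * time.Millisecond) before inspecting statshandler.events, which can be flaky on slow machines. A minimal alternative sketch, assuming only the statshandler and event types added above (the waitForEvents helper itself is hypothetical and not part of this diff), would poll under the handler's mutex until the expected number of events has been recorded or a test context expires:

// waitForEvents is an illustrative helper (not part of this change): it polls
// the statshandler until at least want events have been recorded, or the
// context expires, and returns a copy of the recorded events.
func waitForEvents(ctx context.Context, h *statshandler, want int) []event {
	for ctx.Err() == nil {
		h.mu.Lock()
		if len(h.events) >= want {
			evs := append([]event(nil), h.events...)
			h.mu.Unlock()
			return evs
		}
		h.mu.Unlock()
		time.Sleep(10 * time.Millisecond)
	}
	return nil
}

With a helper along these lines, the sequence tests could call verifyEventSequence(t, waitForEvents(ctx, h, len(expectedUnarySequence)), expectedUnarySequence) against a test context instead of relying on a fixed sleep.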