diff --git a/stats/stats.go b/stats/stats.go index 6f20d2d54868..9230a008128a 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -36,7 +36,10 @@ type RPCStats interface { IsClient() bool } -// Begin contains stats when an RPC attempt begins. +// Begin contains stats when an RPC attempt begins. This event is called after +// the InHeader event, as headers must be processed before the RPC lifecycle +// begins. +// // FailFast is only valid if this Begin is from client side. type Begin struct { // Client is true if this Begin is from client side. @@ -98,7 +101,9 @@ func (s *InPayload) IsClient() bool { return s.Client } func (s *InPayload) isRPCStats() {} -// InHeader contains stats when a header is received. +// InHeader contain stats when the header is received. It is the first event in +// the server after receiving the RPC. It is followed by the OutPayload +// server event. type InHeader struct { // Client is true if this InHeader is from client side. Client bool diff --git a/stats/stats_test.go b/stats/stats_test.go index ec5ffa042f47..c3205ba5b21e 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -242,6 +242,8 @@ func newTest(t *testing.T, tc *testConfig, chs []stats.Handler, shs []stats.Hand // startServer starts a gRPC server listening. Callers should defer a // call to te.tearDown to clean up. +// +// Uses deprecated opts rpc.(RPCCompressor, RPCDecompressor, WithBlock, Dial) func (te *test) startServer(ts testgrpc.TestServiceServer) { te.testServer = ts lis, err := net.Listen("tcp", "localhost:0") @@ -786,8 +788,13 @@ func checkConnEnd(t *testing.T, d *gotData) { st.IsClient() // TODO remove this. } +type event struct { + eventType string +} + type statshandler struct { mu sync.Mutex + events []event gotRPC []*gotData gotConn []*gotData } @@ -800,13 +807,41 @@ func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) conte return context.WithValue(ctx, rpcCtxKey{}, info) } +// recordEvent records an event in the statshandler along with a timestamp. +func (h *statshandler) recordEvent(eventType string) { + h.mu.Lock() + defer h.mu.Unlock() + h.events = append(h.events, event{eventType: eventType}) +} + func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) { + h.recordEvent("ConnStats") + h.mu.Lock() defer h.mu.Unlock() h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s}) } func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) { + switch s.(type) { + case *stats.Begin: + h.recordEvent("Begin") + case *stats.InHeader: + h.recordEvent("InHeader") + case *stats.InPayload: + h.recordEvent("InPayload") + case *stats.OutHeader: + h.recordEvent("OutHeader") + case *stats.OutPayload: + h.recordEvent("OutPayload") + case *stats.InTrailer: + h.recordEvent("InTrailer") + case *stats.OutTrailer: + h.recordEvent("OutTrailer") + case *stats.End: + h.recordEvent("End") + } + h.mu.Lock() defer h.mu.Unlock() h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s}) @@ -1519,3 +1554,122 @@ func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) { } wg.Wait() } + +// TestServerStatsUnaryRPCEventSequence tests that the sequence of server-side stats +// events for a Unary RPC matches the expected flow. +func (s) TestServerStatsUnaryRPCEventSequence(t *testing.T) { + h := &statshandler{} + te := newTest(t, &testConfig{compress: ""}, nil, []stats.Handler{h}) + te.startServer(&testServer{}) + defer te.tearDown() + + _, _, err := te.doUnaryCall(&rpcConfig{success: true, callType: unaryRPC}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // Allow time for events to propagate + time.Sleep(50 * time.Millisecond) + // Verify sequence + h.mu.Lock() + defer h.mu.Unlock() + // To verify if the Unary RPC server stats events are logged in the + // correct order. + wantedUnarySequence := []string{ + "ConnStats", + "InHeader", + "Begin", + "InPayload", + "OutHeader", + "OutPayload", + "OutTrailer", + "End", + } + verifyEventSequence(t, h.events, wantedUnarySequence) +} + +// TestServerStatsClientStreamEventSequence tests that the sequence of server-side +// stats events for a Client Stream RPC matches the expected flow. +func (s) TestServerStatsClientStreamEventSequence(t *testing.T) { + h := &statshandler{} + te := newTest(t, &testConfig{compress: "gzip"}, nil, []stats.Handler{h}) + te.startServer(&testServer{}) + defer te.tearDown() + + _, _, err := te.doClientStreamCall(&rpcConfig{count: 5, success: true, callType: clientStreamRPC}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + h.mu.Lock() + defer h.mu.Unlock() + // To verify if the Client Stream RPC server stats events are logged in the + // correct order. + wantedClientStreamSequence := []string{ + "ConnStats", + "InHeader", + "Begin", + "OutHeader", + "InPayload", + "InPayload", + "InPayload", + "InPayload", + "InPayload", + "OutPayload", + "OutTrailer", + "End", + } + verifyEventSequence(t, h.events, wantedClientStreamSequence) +} + +// TestServerStatsClientStreamEventSequence tests that the sequence of server-side +// stats events for a Client Stream RPC matches the expected flow. +func (s) TestServerStatsServerStreamEventSequence(t *testing.T) { + h := &statshandler{} + te := newTest(t, &testConfig{compress: "gzip"}, nil, []stats.Handler{h}) + te.startServer(&testServer{}) + defer te.tearDown() + + _, _, err := te.doServerStreamCall(&rpcConfig{count: 5, success: true, callType: serverStreamRPC}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + h.mu.Lock() + defer h.mu.Unlock() + + // To verify if the Server Stream RPC server stats events are logged in the + // correct order. + wantedServerStreamSequence := []string{ + "ConnStats", + "InHeader", + "Begin", + "InPayload", + "OutHeader", + "OutPayload", + "OutPayload", + "OutPayload", + "OutPayload", + "OutPayload", + "OutTrailer", + "End", + } + verifyEventSequence(t, h.events, wantedServerStreamSequence) +} + +// verifyEventSequence verifies that a sequence of recorded events matches +// the expected sequence. +func verifyEventSequence(t *testing.T, got []event, expected []string) { + t.Helper() + // Extract event types from `got` for comparison. + gotEventTypes := make([]string, len(got)) + for i, e := range got { + gotEventTypes[i] = e.eventType + } + if !cmp.Equal(gotEventTypes, expected) { + t.Errorf("Event sequence mismatch (-got +expected):\n%s", cmp.Diff(gotEventTypes, expected)) + } +}