diff --git a/authz/audit/audit_logging_test.go b/authz/audit/audit_logging_test.go index ea84db099608..5db4487e227f 100644 --- a/authz/audit/audit_logging_test.go +++ b/authz/audit/audit_logging_test.go @@ -24,7 +24,6 @@ import ( "crypto/x509" "encoding/json" "io" - "net" "os" "testing" "time" @@ -271,17 +270,13 @@ func (s) TestAuditLogger(t *testing.T) { grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor)) defer s.Stop() - testgrpc.RegisterTestServiceServer(s, ss) - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Error listening: %v", err) - } - go s.Serve(lis) + ss.S = s + stubserver.StartTestService(t, ss) // Setup gRPC test client with certificates containing a SPIFFE Id. - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(clientCreds)) + clientConn, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", ss.Address, err) } defer clientConn.Close() client := testgrpc.NewTestServiceClient(clientConn) diff --git a/authz/grpc_authz_end2end_test.go b/authz/grpc_authz_end2end_test.go index 4e798f7ca3d7..6ddc8dbf78f0 100644 --- a/authz/grpc_authz_end2end_test.go +++ b/authz/grpc_authz_end2end_test.go @@ -34,6 +34,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/grpc/testdata" @@ -42,26 +43,6 @@ import ( testpb "google.golang.org/grpc/interop/grpc_testing" ) -type testServer struct { - testgrpc.UnimplementedTestServiceServer -} - -func (s *testServer) UnaryCall(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - return &testpb.SimpleResponse{}, nil -} - -func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error { - for { - _, err := stream.Recv() - if err == io.EOF { - return stream.SendAndClose(&testpb.StreamingInputCallResponse{}) - } - if err != nil { - return err - } - } -} - type s struct { grpctest.Tester } @@ -313,17 +294,33 @@ func (s) TestStaticPolicyEnd2End(t *testing.T) { t.Run(name, func(t *testing.T) { // Start a gRPC server with gRPC authz unary and stream server interceptors. i, _ := authz.NewStatic(test.authzPolicy) - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor), - grpc.ChainStreamInterceptor(i.StreamInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } - go s.Serve(lis) + stub := &stubserver.StubServer{ + Listener: lis, + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&testpb.StreamingInputCallResponse{}) + } + if err != nil { + return err + } + } + }, + S: grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)), + } + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -383,17 +380,22 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnTLSAuthenticatedConnection(t * if err != nil { t.Fatalf("failed to generate credentials: %v", err) } - s := grpc.NewServer( - grpc.Creds(creds), - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } - go s.Serve(lis) + stub := &stubserver.StubServer{ + Listener: lis, + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + S: grpc.NewServer( + grpc.Creds(creds), + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), + } + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. creds, err = credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") @@ -448,17 +450,22 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnMTLSAuthenticatedConnection(t Certificates: []tls.Certificate{cert}, ClientCAs: certPool, }) - s := grpc.NewServer( - grpc.Creds(creds), - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } - go s.Serve(lis) + stub := &stubserver.StubServer{ + Listener: lis, + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + S: grpc.NewServer( + grpc.Creds(creds), + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), + } + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. cert, err = tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) @@ -501,19 +508,34 @@ func (s) TestFileWatcherEnd2End(t *testing.T) { i, _ := authz.NewFileWatcher(file, 1*time.Second) defer i.Close() - // Start a gRPC server with gRPC authz unary and stream server interceptors. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor), - grpc.ChainStreamInterceptor(i.StreamInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } defer lis.Close() - go s.Serve(lis) + stub := &stubserver.StubServer{ + Listener: lis, + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&testpb.StreamingInputCallResponse{}) + } + if err != nil { + return err + } + } + }, + // Start a gRPC server with gRPC authz unary and stream server interceptors. + S: grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)), + } + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -571,18 +593,22 @@ func (s) TestFileWatcher_ValidPolicyRefresh(t *testing.T) { i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) defer i.Close() - // Start a gRPC server with gRPC authz unary server interceptor. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } defer lis.Close() - go s.Serve(lis) + stub := &stubserver.StubServer{ + Listener: lis, + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + // Start a gRPC server with gRPC authz unary server interceptor. + S: grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), + } + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -619,18 +645,22 @@ func (s) TestFileWatcher_InvalidPolicySkipReload(t *testing.T) { i, _ := authz.NewFileWatcher(file, 20*time.Millisecond) defer i.Close() - // Start a gRPC server with gRPC authz unary server interceptors. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } defer lis.Close() - go s.Serve(lis) + stub := &stubserver.StubServer{ + Listener: lis, + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + // Start a gRPC server with gRPC authz unary server interceptors. + S: grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), + } + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -670,18 +700,22 @@ func (s) TestFileWatcher_RecoversFromReloadFailure(t *testing.T) { i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) defer i.Close() - // Start a gRPC server with gRPC authz unary server interceptors. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } defer lis.Close() - go s.Serve(lis) + + stub := &stubserver.StubServer{ + Listener: lis, + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + S: grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), + } + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/internal/stubserver/stubserver.go b/internal/stubserver/stubserver.go index 2e404e294bf6..92262af877a6 100644 --- a/internal/stubserver/stubserver.go +++ b/internal/stubserver/stubserver.go @@ -56,9 +56,11 @@ type StubServer struct { testgrpc.TestServiceServer // Customizable implementations of server handlers. - EmptyCallF func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) - UnaryCallF func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) - FullDuplexCallF func(stream testgrpc.TestService_FullDuplexCallServer) error + EmptyCallF func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) + UnaryCallF func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) + FullDuplexCallF func(stream testgrpc.TestService_FullDuplexCallServer) error + StreamingInputCallF func(stream testgrpc.TestService_StreamingInputCallServer) error // Client-Streaming request + StreamingOutputCallF func(req *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error // Server-streaming response // A client connected to this service the test may use. Created in Start(). Client testgrpc.TestServiceClient @@ -101,6 +103,16 @@ func (ss *StubServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallS return ss.FullDuplexCallF(stream) } +// StreamingInputCall is the handler for testpb.StreamingInputCall +func (ss *StubServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error { + return ss.StreamingInputCallF(stream) +} + +// StreamingOutputCall is the handler for testpb.StreamingOutputCall +func (ss *StubServer) StreamingOutputCall(req *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error { + return ss.StreamingOutputCallF(req, stream) +} + // Start starts the server and creates a client connected to it. func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) error { if err := ss.StartServer(sopts...); err != nil {