diff --git a/cmd/tools/config/generate.go b/cmd/tools/config/generate.go index 218d816d9e6d8..2c709539ec450 100644 --- a/cmd/tools/config/generate.go +++ b/cmd/tools/config/generate.go @@ -297,6 +297,10 @@ func WriteYaml(w io.Writer) { name: "tls", header: "\n# Configure the proxy tls enable.", }, + { + name: "internaltls", + header: "\n# Configure the node-tls enable.", + }, { name: "common", }, diff --git a/configs/milvus.yaml b/configs/milvus.yaml index a71c8447d6a60..7cbd7f89a4c56 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -784,6 +784,12 @@ tls: serverKeyPath: configs/cert/server.key caPemPath: configs/cert/ca.pem +# Configure the node-tls enable. +internaltls: + serverPemPath: configs/cert/server.pem + serverKeyPath: configs/cert/server.key + caPemPath: configs/cert/ca.pem + common: defaultPartitionName: _default # Name of the default partition when a collection is created defaultIndexName: _default_idx # Name of the index when it is created with name unspecified @@ -839,6 +845,7 @@ common: privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition # Collection level readwrite privileges admin: privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition,CreateAlias,DropAlias # Collection level admin privileges + internaltlsEnabled: false tlsMode: 0 session: ttl: 30 # ttl value when session granting a lease to register service diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index df5fecb1af4f6..bb095cdae30b0 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -71,6 +72,15 @@ func NewClient(ctx context.Context) (types.DataCoordClient, error) { client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetSession(sess) + if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() { + client.grpcClient.EnableEncryption() + cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "Datacoord") + if err != nil { + log.Error("Failed to create cert pool for Datacoord client") + return nil, err + } + client.grpcClient.SetInternalTLSCertPool(cp) + } return client, nil } diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 998bb21106ab7..ee17f8c0d3a03 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -174,7 +174,7 @@ func (s *Server) startGrpcLoop() { Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - s.grpcServer = grpc.NewServer( + grpcOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), @@ -201,7 +201,11 @@ func (s *Server) startGrpcLoop() { }), streamingserviceinterceptor.NewStreamingServiceStreamServerInterceptor(), )), - grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler())) + grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), + } + + grpcOpts = append(grpcOpts, utils.EnableInternalTLS("DataCoord")) + s.grpcServer = grpc.NewServer(grpcOpts...) indexpb.RegisterIndexCoordServer(s.grpcServer, s) datapb.RegisterDataCoordServer(s.grpcServer, s) // register the streaming coord grpc service. diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index 67d5081a19e8c..5859e7ee33e5c 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" @@ -72,6 +73,15 @@ func NewClient(ctx context.Context, addr string, serverID int64) (types.DataNode client.grpcClient.SetNodeID(serverID) client.grpcClient.SetSession(sess) + if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() { + client.grpcClient.EnableEncryption() + cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "DataNode") + if err != nil { + log.Error("Failed to create cert pool for DataNode client") + return nil, err + } + client.grpcClient.SetInternalTLSCertPool(cp) + } return client, nil } diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 5e4ae6f0095e9..2c08267b7bfe2 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -129,7 +129,7 @@ func (s *Server) startGrpcLoop() { Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - s.grpcServer = grpc.NewServer( + grpcOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), @@ -154,7 +154,11 @@ func (s *Server) startGrpcLoop() { return s.serverID.Load() }), )), - grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler())) + grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), + } + + grpcOpts = append(grpcOpts, utils.EnableInternalTLS("DataNode")) + s.grpcServer = grpc.NewServer(grpcOpts...) datapb.RegisterDataNodeServer(s.grpcServer, s) ctx, cancel := context.WithCancel(s.ctx) diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go index cb301bd7d61ef..7387bdb1385e3 100644 --- a/internal/distributed/indexnode/client/client.go +++ b/internal/distributed/indexnode/client/client.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/internal/types" @@ -72,6 +73,15 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool) if encryption { client.grpcClient.EnableEncryption() } + if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() { + client.grpcClient.EnableEncryption() + cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "IndexNode") + if err != nil { + log.Error("Failed to create cert pool for IndexNode client") + return nil, err + } + client.grpcClient.SetInternalTLSCertPool(cp) + } return client, nil } diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go index 403343ee907c4..8108615450441 100644 --- a/internal/distributed/indexnode/service.go +++ b/internal/distributed/indexnode/service.go @@ -114,7 +114,7 @@ func (s *Server) startGrpcLoop() { Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - s.grpcServer = grpc.NewServer( + grpcOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), @@ -139,7 +139,11 @@ func (s *Server) startGrpcLoop() { return s.serverID.Load() }), )), - grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler())) + grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), + } + + grpcOpts = append(grpcOpts, utils.EnableInternalTLS("IndexNode")) + s.grpcServer = grpc.NewServer(grpcOpts...) workerpb.RegisterIndexNodeServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) if err := s.grpcServer.Serve(s.listener); err != nil { diff --git a/internal/distributed/proxy/client/client.go b/internal/distributed/proxy/client/client.go index 549cc9671930c..ffbc91ab20ca3 100644 --- a/internal/distributed/proxy/client/client.go +++ b/internal/distributed/proxy/client/client.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/types" @@ -69,6 +70,15 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetNodeID(nodeID) client.grpcClient.SetSession(sess) + if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() { + client.grpcClient.EnableEncryption() + cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "Proxy") + if err != nil { + log.Error("Failed to create cert pool for Proxy client") + return nil, err + } + client.grpcClient.SetInternalTLSCertPool(cp) + } return client, nil } diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index a4ac8ae1e10fb..4647cfcc423c3 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -342,7 +342,7 @@ func (s *Server) startInternalGrpc(errChan chan error) { } opts := tracer.GetInterceptorOpts() - s.grpcInternalServer = grpc.NewServer( + grpcOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), @@ -366,7 +366,12 @@ func (s *Server) startInternalGrpc(errChan chan error) { } return s.serverID.Load() }), - ))) + )), + grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), + } + + grpcOpts = append(grpcOpts, utils.EnableInternalTLS("Proxy")) + s.grpcInternalServer = grpc.NewServer(grpcOpts...) proxypb.RegisterProxyServer(s.grpcInternalServer, s) grpc_health_v1.RegisterHealthServer(s.grpcInternalServer, s) errChan <- nil diff --git a/internal/distributed/querycoord/client/client.go b/internal/distributed/querycoord/client/client.go index 97ebaf3cb68f9..867d73e7f6a87 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" @@ -63,6 +64,15 @@ func NewClient(ctx context.Context) (types.QueryCoordClient, error) { client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetSession(sess) + if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() { + client.grpcClient.EnableEncryption() + cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "QueryCoord") + if err != nil { + log.Error("Failed to create cert pool for QueryCoord client") + return nil, err + } + client.grpcClient.SetInternalTLSCertPool(cp) + } return client, nil } diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 25b903c4edc9c..3dbca5bc69363 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -230,7 +230,7 @@ func (s *Server) startGrpcLoop() { ctx, cancel := context.WithCancel(s.loopCtx) defer cancel() - s.grpcServer = grpc.NewServer( + grpcOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), @@ -256,7 +256,10 @@ func (s *Server) startGrpcLoop() { }), )), grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), - ) + } + + grpcOpts = append(grpcOpts, utils.EnableInternalTLS("QueryCoord")) + s.grpcServer = grpc.NewServer(grpcOpts...) querypb.RegisterQueryCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index abd4b714d7dc1..7dfe4dc8be62a 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" @@ -37,6 +38,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +var Params *paramtable.ComponentParam = paramtable.Get() + // Client is the grpc client of QueryNode. type Client struct { grpcClient grpcclient.GrpcClient[querypb.QueryNodeClient] @@ -70,6 +73,15 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC client.grpcClient.SetNodeID(nodeID) client.grpcClient.SetSession(sess) + if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() { + client.grpcClient.EnableEncryption() + cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "QueryNode") + if err != nil { + log.Error("Failed to create cert pool for QueryNode client") + return nil, err + } + client.grpcClient.SetInternalTLSCertPool(cp) + } return client, nil } diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 31a3074b12309..00e898de70b97 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -176,7 +176,7 @@ func (s *Server) startGrpcLoop() { Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - s.grpcServer = grpc.NewServer( + grpcOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), @@ -204,7 +204,10 @@ func (s *Server) startGrpcLoop() { }), )), grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), - ) + } + + grpcOpts = append(grpcOpts, utils.EnableInternalTLS("QueryNode")) + s.grpcServer = grpc.NewServer(grpcOpts...) querypb.RegisterQueryNodeServer(s.grpcServer, s) ctx, cancel := context.WithCancel(s.ctx) diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index 0ecaf83c7d63d..6d0c871042366 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" @@ -70,6 +71,15 @@ func NewClient(ctx context.Context) (types.RootCoordClient, error) { client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetSession(sess) + if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() { + client.grpcClient.EnableEncryption() + cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "RootCoord") + if err != nil { + log.Error("Failed to create cert pool for RootCoord client") + return nil, err + } + client.grpcClient.SetInternalTLSCertPool(cp) + } return client, nil } diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index f80f9ece902da..d329a1bc1f2c5 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -278,7 +278,7 @@ func (s *Server) startGrpcLoop() { ctx, cancel := context.WithCancel(s.ctx) defer cancel() - s.grpcServer = grpc.NewServer( + grpcOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()), @@ -303,7 +303,11 @@ func (s *Server) startGrpcLoop() { return s.serverID.Load() }), )), - grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler())) + grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()), + } + + grpcOpts = append(grpcOpts, utils.EnableInternalTLS("RootCoord")) + s.grpcServer = grpc.NewServer(grpcOpts...) rootcoordpb.RegisterRootCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) diff --git a/internal/distributed/utils/util.go b/internal/distributed/utils/util.go index f2cc161ead0b1..a61b6ff2b0905 100644 --- a/internal/distributed/utils/util.go +++ b/internal/distributed/utils/util.go @@ -1,9 +1,14 @@ package utils import ( + "crypto/x509" + "os" "time" + "github.com/cockroachdb/errors" + "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -30,3 +35,47 @@ func GracefulStopGRPCServer(s *grpc.Server) { <-ch } } + +func getTLSCreds(certFile string, keyFile string, nodeType string) credentials.TransportCredentials { + log.Info("TLS Server PEM Path", zap.String("path", certFile)) + log.Info("TLS Server Key Path", zap.String("path", keyFile)) + creds, err := credentials.NewServerTLSFromFile(certFile, keyFile) + if err != nil { + log.Warn(nodeType+" can't create creds", zap.Error(err)) + log.Warn(nodeType+" can't create creds", zap.Error(err)) + } + return creds +} + +func EnableInternalTLS(NodeType string) grpc.ServerOption { + var Params *paramtable.ComponentParam = paramtable.Get() + certFile := Params.InternalTLSCfg.InternalTLSServerPemPath.GetValue() + keyFile := Params.InternalTLSCfg.InternalTLSServerKeyPath.GetValue() + internaltlsEnabled := Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() + + log.Info("Internal TLS Enabled", zap.Bool("value", internaltlsEnabled)) + + if internaltlsEnabled { + creds := getTLSCreds(certFile, keyFile, NodeType) + return grpc.Creds(creds) + } + return grpc.Creds(nil) +} + +func CreateCertPoolforClient(caFile string, nodeType string) (*x509.CertPool, error) { + log.Info("Creating cert pool for " + nodeType) + log.Info("Cert file path:", zap.String("caFile", caFile)) + certPool := x509.NewCertPool() + + b, err := os.ReadFile(caFile) + if err != nil { + log.Error("Error reading cert file in client", zap.Error(err)) + return nil, err + } + + if !certPool.AppendCertsFromPEM(b) { + log.Error("credentials: failed to append certificates") + return nil, errors.New("failed to append certificates") // Cert pool is invalid, return nil and the error + } + return certPool, err +} diff --git a/internal/mocks/mock_grpc_client.go b/internal/mocks/mock_grpc_client.go index e47fa2bf400ff..4bd3c96f7de2f 100644 --- a/internal/mocks/mock_grpc_client.go +++ b/internal/mocks/mock_grpc_client.go @@ -12,6 +12,8 @@ import ( mock "github.com/stretchr/testify/mock" sessionutil "github.com/milvus-io/milvus/internal/util/sessionutil" + + x509 "crypto/x509" ) // MockGrpcClient is an autogenerated mock type for the GrpcClient type @@ -325,6 +327,39 @@ func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) RunAndReturn(run func(func() (s return _c } +// SetInternalTLSCertPool provides a mock function with given fields: cp +func (_m *MockGrpcClient[T]) SetInternalTLSCertPool(cp *x509.CertPool) { + _m.Called(cp) +} + +// MockGrpcClient_SetInternalTLSCertPool_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetInternalTLSCertPool' +type MockGrpcClient_SetInternalTLSCertPool_Call[T grpcclient.GrpcComponent] struct { + *mock.Call +} + +// SetInternalTLSCertPool is a helper method to define mock.On call +// - cp *x509.CertPool +func (_e *MockGrpcClient_Expecter[T]) SetInternalTLSCertPool(cp interface{}) *MockGrpcClient_SetInternalTLSCertPool_Call[T] { + return &MockGrpcClient_SetInternalTLSCertPool_Call[T]{Call: _e.mock.On("SetInternalTLSCertPool", cp)} +} + +func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) Run(run func(cp *x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*x509.CertPool)) + }) + return _c +} + +func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) Return() *MockGrpcClient_SetInternalTLSCertPool_Call[T] { + _c.Call.Return() + return _c +} + +func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(*x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] { + _c.Call.Return(run) + return _c +} + // SetNewGrpcClientFunc provides a mock function with given fields: _a0 func (_m *MockGrpcClient[T]) SetNewGrpcClientFunc(_a0 func(*grpc.ClientConn) T) { _m.Called(_a0) diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index dd9e805da5e31..8927b44d207b1 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -19,6 +19,7 @@ package grpcclient import ( "context" "crypto/tls" + "crypto/x509" "strings" "sync" "time" @@ -84,6 +85,7 @@ type GrpcClient[T GrpcComponent] interface { GetRole() string SetGetAddrFunc(func() (string, error)) EnableEncryption() + SetInternalTLSCertPool(cp *x509.CertPool) SetNewGrpcClientFunc(func(cc *grpc.ClientConn) T) ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error) Call(ctx context.Context, caller func(client T) (any, error)) (any, error) @@ -101,9 +103,10 @@ type ClientBase[T interface { newGrpcClient func(cc *grpc.ClientConn) T // grpcClient T - grpcClient *clientConnWrapper[T] - encryption bool - addr atomic.String + grpcClient *clientConnWrapper[T] + encryption bool + cpInternalTLS *x509.CertPool + addr atomic.String // conn *grpc.ClientConn grpcClientMtx sync.RWMutex role string @@ -187,6 +190,10 @@ func (c *ClientBase[T]) EnableEncryption() { c.encryption = true } +func (c *ClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) { + c.cpInternalTLS = cp +} + // SetNewGrpcClientFunc sets newGrpcClient of client func (c *ClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) { c.newGrpcClient = f @@ -257,11 +264,12 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { compress = Zstd } if c.encryption { + log.Debug("Running in internalTLS mode with encryption enabled") conn, err = grpc.DialContext( dialContext, addr, // #nosec G402 - grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})), + grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: c.cpInternalTLS})), grpc.WithBlock(), grpc.WithDefaultCallOptions( grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize), diff --git a/internal/util/mock/grpcclient.go b/internal/util/mock/grpcclient.go index b466f097c3759..5ffa9d8bd7a2e 100644 --- a/internal/util/mock/grpcclient.go +++ b/internal/util/mock/grpcclient.go @@ -18,6 +18,7 @@ package mock import ( "context" + "crypto/x509" "fmt" "sync" @@ -37,6 +38,7 @@ type GRPCClientBase[T any] struct { newGrpcClient func(cc *grpc.ClientConn) T grpcClient T + cpInternalTLS *x509.CertPool conn *grpc.ClientConn grpcClientMtx sync.RWMutex GetGrpcClientErr error @@ -60,6 +62,10 @@ func (c *GRPCClientBase[T]) SetRole(role string) { func (c *GRPCClientBase[T]) EnableEncryption() { } +func (c *GRPCClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) { + c.cpInternalTLS = cp +} + func (c *GRPCClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) { c.newGrpcClient = f } diff --git a/pkg/util/paramtable/base_table.go b/pkg/util/paramtable/base_table.go index 79f3e2c82aa5d..3d3bbd38b4e6c 100644 --- a/pkg/util/paramtable/base_table.go +++ b/pkg/util/paramtable/base_table.go @@ -64,9 +64,10 @@ func globalConfigPrefixs() []string { return []string{"metastore", "localStorage", "etcd", "tikv", "minio", "pulsar", "kafka", "rocksmq", "log", "grpc", "common", "quotaAndLimits", "trace"} } -// support read "milvus.yaml", "default.yaml", "user.yaml" as this order. -// order: milvus.yaml < default.yaml < user.yaml, do not change the order below -var defaultYaml = []string{"milvus.yaml", "default.yaml", "user.yaml"} +// support read "milvus.yaml", "_test.yaml", "default.yaml", "user.yaml" as this order. +// order: milvus.yaml < _test.yaml < default.yaml < user.yaml, do not change the order below. +// Use _test.yaml only for test related purpose. +var defaultYaml = []string{"milvus.yaml", "_test.yaml", "default.yaml", "user.yaml"} // BaseTable the basics of paramtable type BaseTable struct { diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index a2c3437e85f35..55bb1a1894d02 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -83,6 +83,8 @@ type ComponentParam struct { RbacConfig rbacConfig StreamingCfg streamingConfig + InternalTLSCfg InternalTLSConfig + RootCoordGrpcServerCfg GrpcServerConfig ProxyGrpcServerCfg GrpcServerConfig QueryCoordGrpcServerCfg GrpcServerConfig @@ -139,6 +141,8 @@ func (p *ComponentParam) init(bt *BaseTable) { p.GpuConfig.init(bt) p.KnowhereConfig.init(bt) + p.InternalTLSCfg.Init(bt) + p.RootCoordGrpcServerCfg.Init("rootCoord", bt) p.ProxyGrpcServerCfg.Init("proxy", bt) p.ProxyGrpcServerCfg.InternalPort.Export = true diff --git a/pkg/util/paramtable/grpc_param.go b/pkg/util/paramtable/grpc_param.go index f2afef49007ec..53e394c07ccb8 100644 --- a/pkg/util/paramtable/grpc_param.go +++ b/pkg/util/paramtable/grpc_param.go @@ -535,3 +535,41 @@ func (p *GrpcClientConfig) GetDefaultRetryPolicy() map[string]interface{} { "backoffMultiplier": p.BackoffMultiplier.GetAsFloat(), } } + +type InternalTLSConfig struct { + InternalTLSEnabled ParamItem `refreshable:"false"` + InternalTLSServerPemPath ParamItem `refreshable:"false"` + InternalTLSServerKeyPath ParamItem `refreshable:"false"` + InternalTLSCaPemPath ParamItem `refreshable:"false"` +} + +func (p *InternalTLSConfig) Init(base *BaseTable) { + p.InternalTLSEnabled = ParamItem{ + Key: "common.security.internaltlsEnabled", + Version: "2.0.0", + DefaultValue: "false", + Export: true, + } + p.InternalTLSEnabled.Init(base.mgr) + + p.InternalTLSServerPemPath = ParamItem{ + Key: "internaltls.serverPemPath", + Version: "2.0.0", + Export: true, + } + p.InternalTLSServerPemPath.Init(base.mgr) + + p.InternalTLSServerKeyPath = ParamItem{ + Key: "internaltls.serverKeyPath", + Version: "2.0.0", + Export: true, + } + p.InternalTLSServerKeyPath.Init(base.mgr) + + p.InternalTLSCaPemPath = ParamItem{ + Key: "internaltls.caPemPath", + Version: "2.0.0", + Export: true, + } + p.InternalTLSCaPemPath.Init(base.mgr) +} diff --git a/pkg/util/paramtable/grpc_param_test.go b/pkg/util/paramtable/grpc_param_test.go index d1970bec8a14a..bf07dfeaa98e7 100644 --- a/pkg/util/paramtable/grpc_param_test.go +++ b/pkg/util/paramtable/grpc_param_test.go @@ -178,3 +178,19 @@ func TestGrpcClientParams(t *testing.T) { assert.Equal(t, clientConfig.ServerKeyPath.GetValue(), "/key") assert.Equal(t, clientConfig.CaPemPath.GetValue(), "/ca") } + +func TestInternalTLSParams(t *testing.T) { + base := ComponentParam{} + base.Init(NewBaseTable(SkipRemote(true))) + var internalTLSCfg InternalTLSConfig + internalTLSCfg.Init(base.baseTable) + + base.Save("common.security.internalTlsEnabled", "true") + base.Save("internaltls.serverPemPath", "/pem") + base.Save("internaltls.serverKeyPath", "/key") + base.Save("internaltls.caPemPath", "/ca") + assert.Equal(t, internalTLSCfg.InternalTLSEnabled.GetAsBool(), true) + assert.Equal(t, internalTLSCfg.InternalTLSServerPemPath.GetValue(), "/pem") + assert.Equal(t, internalTLSCfg.InternalTLSServerKeyPath.GetValue(), "/key") + assert.Equal(t, internalTLSCfg.InternalTLSCaPemPath.GetValue(), "/ca") +} diff --git a/tests/integration/internaltls/internaltls_test.go b/tests/integration/internaltls/internaltls_test.go new file mode 100644 index 0000000000000..9967051f02ad6 --- /dev/null +++ b/tests/integration/internaltls/internaltls_test.go @@ -0,0 +1,355 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internaltls + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/hookutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/tests/integration" +) + +type InternaltlsTestSuit struct { + integration.MiniClusterSuite + + indexType string + metricType string + vecType schemapb.DataType +} + +// Define the content for the configuration YAML file +var configContent = ` +rootCoord: + ip: localhost + +proxy: + ip: localhost + +queryCoord: + ip: localhost + +queryNode: + ip: localhost + +indexNode: + ip: localhost + +dataCoord: + ip: localhost + +dataNode: + ip: localhost + +common: + security: + internaltlsEnabled : true + +internaltls: + serverPemPath: ../../../configs/cert/server.pem + serverKeyPath: ../../../configs/cert/server.key + caPemPath: ../../../configs/cert/ca.pem +` + +const configFilePath = "../../../configs/_test.yaml" + +// CreateConfigFile creates the YAML configuration file for tests +func CreateConfigFile() { + // Write config content to _test.yaml file + err := os.WriteFile(configFilePath, []byte(configContent), 0o600) + if err != nil { + log.Error("Failed to create config file", zap.Error(err)) + } + log.Info(fmt.Sprintf("Config file created: %s", configFilePath)) +} + +func (s *InternaltlsTestSuit) SetupSuite() { + log.Info("Initializing paramtable...") + CreateConfigFile() + paramtable.Init() + log.Info("Setting up EmbedEtcd...") + s.Require().NoError(s.SetupEmbedEtcd()) +} + +func (s *InternaltlsTestSuit) run() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := s.Cluster + + const ( + dim = 128 + dbName = "" + rowNum = 3000 + ) + + collectionName := "TestHelloMilvus" + funcutil.GenRandomStr() + + schema := integration.ConstructSchemaOfVecDataType(collectionName, dim, true, s.vecType) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: common.DefaultShardsNum, + }) + s.NoError(err) + if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason())) + } + s.Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success) + + log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus)) + showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.Equal(showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + + var fVecColumn *schemapb.FieldData + if s.vecType == schemapb.DataType_SparseFloatVector { + fVecColumn = integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum) + } else { + fVecColumn = integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) + } + hashKeys := integration.GenerateHashKeys(rowNum) + insertCheckReport := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) + defer cancelFunc() + + for { + select { + case <-timeoutCtx.Done(): + s.Fail("insert check timeout") + case report := <-c.Extension.GetReportChan(): + reportInfo := report.(map[string]any) + log.Info("insert report info", zap.Any("reportInfo", reportInfo)) + s.Equal(hookutil.OpTypeInsert, reportInfo[hookutil.OpTypeKey]) + s.NotEqualValues(0, reportInfo[hookutil.RequestDataSizeKey]) + return + } + } + } + go insertCheckReport() + insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + + // flush + flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + + segments, err := c.MetaWatcher.ShowSegments() + s.NoError(err) + s.NotEmpty(segments) + for _, segment := range segments { + log.Info("ShowSegments result", zap.String("segment", segment.String())) + } + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) + + // create index + createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: fVecColumn.FieldName, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(dim, s.indexType, s.metricType), + }) + if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason())) + } + s.NoError(err) + s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode()) + + s.WaitForIndexBuilt(ctx, collectionName, fVecColumn.FieldName) + + // load + loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason())) + } + s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode()) + s.WaitForLoad(ctx, collectionName) + + // search + expr := fmt.Sprintf("%s > 0", integration.Int64Field) + nq := 10 + topk := 10 + roundDecimal := -1 + + params := integration.GetSearchParams(s.indexType, s.metricType) + searchReq := integration.ConstructSearchRequest("", collectionName, expr, + fVecColumn.FieldName, s.vecType, nil, s.metricType, params, nq, dim, topk, roundDecimal) + + searchCheckReport := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) + defer cancelFunc() + + for { + select { + case <-timeoutCtx.Done(): + s.Fail("search check timeout") + case report := <-c.Extension.GetReportChan(): + reportInfo := report.(map[string]any) + log.Info("search report info", zap.Any("reportInfo", reportInfo)) + s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey]) + s.NotEqualValues(0, reportInfo[hookutil.ResultDataSizeKey]) + s.NotEqualValues(0, reportInfo[hookutil.RelatedDataSizeKey]) + s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey]) + return + } + } + } + go searchCheckReport() + searchResult, err := c.Proxy.Search(ctx, searchReq) + err = merr.CheckRPCCall(searchResult, err) + s.NoError(err) + + queryCheckReport := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) + defer cancelFunc() + + for { + select { + case <-timeoutCtx.Done(): + s.Fail("query check timeout") + case report := <-c.Extension.GetReportChan(): + reportInfo := report.(map[string]any) + log.Info("query report info", zap.Any("reportInfo", reportInfo)) + s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey]) + s.NotEqualValues(0, reportInfo[hookutil.ResultDataSizeKey]) + s.NotEqualValues(0, reportInfo[hookutil.RelatedDataSizeKey]) + s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey]) + return + } + } + } + go queryCheckReport() + queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{ + DbName: dbName, + CollectionName: collectionName, + Expr: "", + OutputFields: []string{"count(*)"}, + }) + if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason())) + } + s.NoError(err) + s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode()) + + deleteCheckReport := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) + defer cancelFunc() + + for { + select { + case <-timeoutCtx.Done(): + s.Fail("delete check timeout") + case report := <-c.Extension.GetReportChan(): + reportInfo := report.(map[string]any) + log.Info("delete report info", zap.Any("reportInfo", reportInfo)) + s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey]) + s.EqualValues(2, reportInfo[hookutil.SuccessCntKey]) + s.EqualValues(0, reportInfo[hookutil.RelatedCntKey]) + return + } + } + } + go deleteCheckReport() + deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{ + DbName: dbName, + CollectionName: collectionName, + Expr: integration.Int64Field + " in [1, 2]", + }) + if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason())) + } + s.NoError(err) + s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode()) + + status, err := c.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{ + CollectionName: collectionName, + }) + err = merr.CheckRPCCall(status, err) + s.NoError(err) + + status, err = c.Proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{ + CollectionName: collectionName, + }) + err = merr.CheckRPCCall(status, err) + s.NoError(err) + + log.Info("TestHelloMilvus succeed") +} + +func (s *InternaltlsTestSuit) TestHelloMilvus_basic() { + log.Info("Under test Internal TLS hellomilvus...") + s.indexType = integration.IndexFaissIvfFlat + s.metricType = metric.L2 + s.vecType = schemapb.DataType_FloatVector + s.run() +} + +func (s *InternaltlsTestSuit) TearDownSuite() { + defer func() { + err := os.Remove(configFilePath) + if err != nil { + log.Error("Failed to delete config file:", zap.Error(err)) + return + } + log.Info(fmt.Sprintf("Config file deleted: %s", configFilePath)) + }() + s.MiniClusterSuite.TearDownSuite() +} + +func TestInternalTLS(t *testing.T) { + suite.Run(t, new(InternaltlsTestSuit)) +}