diff --git a/cmd/tools/config/generate.go b/cmd/tools/config/generate.go index 218d816d9e6d8..2c709539ec450 100644 --- a/cmd/tools/config/generate.go +++ b/cmd/tools/config/generate.go @@ -297,6 +297,10 @@ func WriteYaml(w io.Writer) { name: "tls", header: "\n# Configure the proxy tls enable.", }, + { + name: "internaltls", + header: "\n# Configure the node-tls enable.", + }, { name: "common", }, diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 245e3cf933e8b..08165b59f6997 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -783,10 +783,11 @@ tls: serverKeyPath: configs/cert/server.key caPemPath: configs/cert/ca.pem +# Configure the node-tls enable. internaltls: - serverPemPath: #path to server.pem - serverKeyPath: #path to server.key - caPemPath: #path to ca.key + serverPemPath: configs/cert/server.pem + serverKeyPath: configs/cert/server.key + caPemPath: configs/cert/ca.pem common: defaultPartitionName: _default # Name of the default partition when a collection is created @@ -843,8 +844,8 @@ common: privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition # Collection level readwrite privileges admin: privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition,CreateAlias,DropAlias # Collection level admin privileges + internaltlsEnabled: false tlsMode: 0 - internaltlsEnabled : false session: ttl: 30 # ttl value when session granting a lease to register service retryTimes: 30 # retry times when session sending etcd requests diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go index 796d86884ba79..c46dab7235126 100644 --- a/internal/distributed/datacoord/client/client_test.go +++ b/internal/distributed/datacoord/client/client_test.go @@ -25,9 +25,12 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" @@ -36,10 +39,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - testify_mock "github.com/stretchr/testify/mock" - "go.uber.org/zap" ) var mockErr = errors.New("mock grpc err") @@ -73,84 +72,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) } -func Test_InternalTLS(t *testing.T) { - paramtable.Init() - validPath := "../../../../configs/cert1/ca.pem" - - ctx := context.Background() - client, err := NewClient(ctx) - assert.NoError(t, err) - assert.NotNil(t, client) - defer client.Close() - - mockDC := mocks.NewMockDataCoordClient(t) - mockGrpcClient := mocks.NewMockGrpcClient[datapb.DataCoordClient](t) - - // Set mock expectations - mockGrpcClient.EXPECT().Close().Return(nil) - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockGrpcClient.EXPECT().ReCall(testify_mock.Anything, testify_mock.Anything).RunAndReturn(func(ctx context.Context, f func(datapb.DataCoordClient) (interface{}, error)) (interface{}, error) { - return f(mockDC) - }) - // Sub-test for nil cert pool - t.Run("NoCertPool", func(t *testing.T) { - var ErrNoCertPool = errors.New("no cert pool") - mockGrpcClient.EXPECT().SetInternalTLSCertPool(testify_mock.Anything).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.SetInternalTLSCertPool(nil) // Simulate no cert pool - - mockDC.EXPECT().Flush(testify_mock.Anything, testify_mock.Anything).Return(nil, ErrNoCertPool) - - _, err := client.Flush(ctx, &datapb.FlushRequest{}) - assert.Error(t, err) // Check for an error - assert.Equal(t, ErrNoCertPool, err) // Check that it's the expected error - }) - - // Sub-test for invalid certificate path - t.Run("InvalidCertPath", func(t *testing.T) { - invalidCAPath := "invalid/path/to/ca.pem" - cp, err := utils.CreateCertPoolforClient(invalidCAPath, "datacoord") - assert.NotNil(t, err) // Expect an error while creating cert pool - assert.Nil(t, cp) // Cert pool should be nil - }) - - // Sub-test for TLS handshake failure - t.Run("TlsHandshakeFailed", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "datacoord") - assert.Nil(t, err) - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - - mockDC.ExpectedCalls = nil - mockDC.EXPECT().Flush(mock.Anything, mock.Anything).Return(nil, errors.New("TLS handshake failed")) - - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - _, err = client.Flush(ctx, &datapb.FlushRequest{}) - assert.NotNil(t, err) - assert.EqualError(t, err, "TLS handshake failed") - }) - - t.Run("SuccessfulFlush", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "datacoord") - assert.NoError(t, err) - assert.NotNil(t, cp) - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - mockDC.ExpectedCalls = nil - mockDC.EXPECT().Flush(mock.Anything, mock.Anything).Return(&datapb.FlushResponse{ - Status: merr.Success(), - }, mockErr) - - _, err = client.Flush(ctx, &datapb.FlushRequest{}) - assert.NotNil(t, err) - }) -} - func Test_GetComponentStates(t *testing.T) { paramtable.Init() diff --git a/internal/distributed/datanode/client/client_test.go b/internal/distributed/datanode/client/client_test.go index a270054d896ff..03e4b64e74e62 100644 --- a/internal/distributed/datanode/client/client_test.go +++ b/internal/distributed/datanode/client/client_test.go @@ -19,16 +19,11 @@ package grpcdatanodeclient import ( "context" "testing" - "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" - testify_mock "github.com/stretchr/testify/mock" "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/distributed/utils" - "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -130,89 +125,3 @@ func Test_NewClient(t *testing.T) { err = client.Close() assert.NoError(t, err) } - -func Test_InternalTLS(t *testing.T) { - paramtable.Init() - validPath := "../../../../configs/cert1/ca.pem" - ctx := context.Background() - client, err := NewClient(ctx, "test", 1) - assert.NoError(t, err) - assert.NotNil(t, client) - defer client.Close() - - mockDataNode := mocks.NewMockDataNodeClient(t) - mockGrpcClient := mocks.NewMockGrpcClient[datapb.DataNodeClient](t) - - mockGrpcClient.EXPECT().Close().Return(nil) - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockGrpcClient.EXPECT().ReCall(testify_mock.Anything, testify_mock.Anything).RunAndReturn(func(ctx context.Context, f func(datapb.DataNodeClient) (interface{}, error)) (interface{}, error) { - return f(mockDataNode) - }) - - t.Run("NoCertPool", func(t *testing.T) { - var ErrNoCertPool = errors.New("no cert pool") - mockGrpcClient.EXPECT().SetInternalTLSCertPool(testify_mock.Anything).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(nil) - - mockDataNode.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, ErrNoCertPool) - - _, err := client.GetComponentStates(ctx, nil) - assert.Error(t, err) - assert.Equal(t, ErrNoCertPool, err) - }) - - // Sub-test for invalid certificate path - t.Run("InvalidCertPath", func(t *testing.T) { - invalidCAPath := "invalid/path/to/ca.pem" - cp, err := utils.CreateCertPoolforClient(invalidCAPath, "datanode") - assert.NotNil(t, err) - assert.Nil(t, cp) - }) - - // Sub-test for TLS handshake failure - t.Run("TlsHandshakeFailed", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "datanode") - assert.Nil(t, err) - mockDataNode.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockDataNode.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, errors.New("TLS handshake failed")) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - _, err = client.GetComponentStates(ctx, nil) - assert.NotNil(t, err) - assert.EqualError(t, err, "TLS handshake failed") - }) - - t.Run("TlsHandshakeSuccess", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "datanode") - assert.Nil(t, err) - mockDataNode.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockDataNode.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(&milvuspb.ComponentStates{}, nil) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - componentStates, err := client.GetComponentStates(ctx, nil) - assert.Nil(t, err) - assert.NotNil(t, componentStates) - assert.IsType(t, &milvuspb.ComponentStates{}, componentStates) - }) - - t.Run("ContextDeadlineExceeded", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - time.Sleep(20 * time.Millisecond) - - _, err := client.GetComponentStates(ctx, nil) - assert.ErrorIs(t, err, context.DeadlineExceeded) - }) -} diff --git a/internal/distributed/indexnode/client/client_test.go b/internal/distributed/indexnode/client/client_test.go index 042ef716ced4c..afa41c152e832 100644 --- a/internal/distributed/indexnode/client/client_test.go +++ b/internal/distributed/indexnode/client/client_test.go @@ -18,7 +18,6 @@ package grpcindexnodeclient import ( "context" - "errors" "math/rand" "os" "strings" @@ -29,17 +28,13 @@ import ( "github.com/stretchr/testify/mock" "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" - testify_mock "github.com/stretchr/testify/mock" ) func TestMain(m *testing.M) { @@ -182,89 +177,3 @@ func TestIndexNodeClient(t *testing.T) { err = client.Close() assert.NoError(t, err) } - -func Test_InternalTLS_IndexNode(t *testing.T) { - paramtable.Init() - validPath := "../../../../configs/cert1/ca.pem" - ctx := context.Background() - client, err := NewClient(ctx, "test", 1, false) - assert.NoError(t, err) - assert.NotNil(t, client) - defer client.Close() - - mockIndexNode := mocks.NewMockIndexNodeClient(t) - mockGrpcClient := mocks.NewMockGrpcClient[indexpb.IndexNodeClient](t) - - mockGrpcClient.EXPECT().Close().Return(nil) - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockGrpcClient.EXPECT().ReCall(testify_mock.Anything, testify_mock.Anything).RunAndReturn(func(ctx context.Context, f func(indexpb.IndexNodeClient) (interface{}, error)) (interface{}, error) { - return f(mockIndexNode) - }) - - t.Run("NoCertPool", func(t *testing.T) { - var ErrNoCertPool = errors.New("no cert pool") - mockGrpcClient.EXPECT().SetInternalTLSCertPool(testify_mock.Anything).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(nil) - - mockIndexNode.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, ErrNoCertPool) - - _, err := client.GetComponentStates(ctx, nil) - assert.Error(t, err) - assert.Equal(t, ErrNoCertPool, err) - }) - - // Sub-test for invalid certificate path - t.Run("InvalidCertPath", func(t *testing.T) { - invalidCAPath := "invalid/path/to/ca.pem" - cp, err := utils.CreateCertPoolforClient(invalidCAPath, "indexnode") - assert.NotNil(t, err) - assert.Nil(t, cp) - }) - - // Sub-test for TLS handshake failure - t.Run("TlsHandshakeFailed", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "indexnode") - assert.Nil(t, err) - mockIndexNode.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockIndexNode.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, errors.New("TLS handshake failed")) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - _, err = client.GetComponentStates(ctx, nil) - assert.NotNil(t, err) - assert.EqualError(t, err, "TLS handshake failed") - }) - - t.Run("TlsHandshakeSuccess", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "indexnode") - assert.Nil(t, err) - mockIndexNode.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockIndexNode.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(&milvuspb.ComponentStates{}, nil) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - componentStates, err := client.GetComponentStates(ctx, nil) - assert.Nil(t, err) - assert.NotNil(t, componentStates) - assert.IsType(t, &milvuspb.ComponentStates{}, componentStates) - }) - - t.Run("ContextDeadlineExceeded", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - time.Sleep(20 * time.Millisecond) - - _, err := client.GetComponentStates(ctx, nil) - assert.ErrorIs(t, err, context.DeadlineExceeded) - }) -} diff --git a/internal/distributed/proxy/client/client_test.go b/internal/distributed/proxy/client/client_test.go index ae23a3ffad1b9..e43b02869cbf2 100644 --- a/internal/distributed/proxy/client/client_test.go +++ b/internal/distributed/proxy/client/client_test.go @@ -18,7 +18,6 @@ package grpcproxyclient import ( "context" - "errors" "testing" "time" @@ -26,13 +25,11 @@ import ( "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - testify_mock "github.com/stretchr/testify/mock" ) func Test_NewClient(t *testing.T) { @@ -52,91 +49,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) } -func Test_InternalTLS(t *testing.T) { - paramtable.Init() - validPath := "../../../../configs/cert1/ca.pem" - ctx := context.Background() - client, err := NewClient(ctx, "test", 1) - assert.NoError(t, err) - assert.NotNil(t, client) - defer client.Close() - - mockProxy := mocks.NewMockProxyClient(t) - mockGrpcClient := mocks.NewMockGrpcClient[proxypb.ProxyClient](t) - - mockGrpcClient.EXPECT().Close().Return(nil) - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockGrpcClient.EXPECT().ReCall(testify_mock.Anything, testify_mock.Anything).RunAndReturn(func(ctx context.Context, f func(proxypb.ProxyClient) (interface{}, error)) (interface{}, error) { - return f(mockProxy) - }) - - t.Run("NoCertPool", func(t *testing.T) { - var ErrNoCertPool = errors.New("no cert pool") - mockGrpcClient.EXPECT().SetInternalTLSCertPool(testify_mock.Anything).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(nil) - - mockProxy.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, ErrNoCertPool) - - _, err := client.GetComponentStates(ctx, nil) - assert.Error(t, err) - assert.Equal(t, ErrNoCertPool, err) - }) - - // Sub-test for invalid certificate path - t.Run("InvalidCertPath", func(t *testing.T) { - invalidCAPath := "invalid/path/to/ca.pem" - cp, err := utils.CreateCertPoolforClient(invalidCAPath, "proxy") - assert.NotNil(t, err) - assert.Nil(t, cp) - }) - - // Sub-test for TLS handshake failure - t.Run("TlsHandshakeFailed", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "proxy") - assert.Nil(t, err) - mockProxy.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockProxy.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, errors.New("TLS handshake failed")) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - _, err = client.GetComponentStates(ctx, nil) - assert.NotNil(t, err) - assert.EqualError(t, err, "TLS handshake failed") - }) - - t.Run("TlsHandshakeSuccess", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "proxy") - assert.Nil(t, err) - mockProxy.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockProxy.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(&milvuspb.ComponentStates{}, nil) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - componentStates, err := client.GetComponentStates(ctx, nil) - assert.Nil(t, err) - assert.NotNil(t, componentStates) - assert.IsType(t, &milvuspb.ComponentStates{}, componentStates) - }) - - t.Run("ContextDeadlineExceeded", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - time.Sleep(20 * time.Millisecond) - - _, err := client.GetComponentStates(ctx, nil) - assert.ErrorIs(t, err, context.DeadlineExceeded) - }) -} func Test_GetComponentStates(t *testing.T) { paramtable.Init() diff --git a/internal/distributed/querycoord/client/client_test.go b/internal/distributed/querycoord/client/client_test.go index caa5179066189..0b14ed48b2fa0 100644 --- a/internal/distributed/querycoord/client/client_test.go +++ b/internal/distributed/querycoord/client/client_test.go @@ -29,15 +29,11 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/distributed/utils" - "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" - testify_mock "github.com/stretchr/testify/mock" ) func TestMain(m *testing.M) { @@ -228,89 +224,3 @@ func Test_NewClient(t *testing.T) { err = client.Close() assert.NoError(t, err) } - -func Test_InternalTLS(t *testing.T) { - paramtable.Init() - validPath := "../../../../configs/cert1/ca.pem" - ctx := context.Background() - client, err := NewClient(ctx) - assert.NoError(t, err) - assert.NotNil(t, client) - defer client.Close() - - mockQC := mocks.NewMockQueryCoordClient(t) - mockGrpcClient := mocks.NewMockGrpcClient[querypb.QueryCoordClient](t) - - mockGrpcClient.EXPECT().Close().Return(nil) - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockGrpcClient.EXPECT().ReCall(testify_mock.Anything, testify_mock.Anything).RunAndReturn(func(ctx context.Context, f func(querypb.QueryCoordClient) (interface{}, error)) (interface{}, error) { - return f(mockQC) - }) - - t.Run("NoCertPool", func(t *testing.T) { - var ErrNoCertPool = errors.New("no cert pool") - mockGrpcClient.EXPECT().SetInternalTLSCertPool(testify_mock.Anything).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(nil) - - mockQC.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, ErrNoCertPool) - - _, err := client.GetComponentStates(ctx, nil) - assert.Error(t, err) - assert.Equal(t, ErrNoCertPool, err) - }) - - // Sub-test for invalid certificate path - t.Run("InvalidCertPath", func(t *testing.T) { - invalidCAPath := "invalid/path/to/ca.pem" - cp, err := utils.CreateCertPoolforClient(invalidCAPath, "querycoord") - assert.NotNil(t, err) - assert.Nil(t, cp) - }) - - // Sub-test for TLS handshake failure - t.Run("TlsHandshakeFailed", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "querycoord") - assert.Nil(t, err) - mockQC.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockQC.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, errors.New("TLS handshake failed")) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - _, err = client.GetComponentStates(ctx, nil) - assert.NotNil(t, err) - assert.EqualError(t, err, "TLS handshake failed") - }) - - t.Run("TlsHandshakeSuccess", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "querycoord") - assert.Nil(t, err) - mockQC.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockQC.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(&milvuspb.ComponentStates{}, nil) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - componentStates, err := client.GetComponentStates(ctx, nil) - assert.Nil(t, err) - assert.NotNil(t, componentStates) - assert.IsType(t, &milvuspb.ComponentStates{}, componentStates) - }) - - t.Run("ContextDeadlineExceeded", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - time.Sleep(20 * time.Millisecond) - - _, err := client.GetComponentStates(ctx, nil) - assert.ErrorIs(t, err, context.DeadlineExceeded) - }) -} diff --git a/internal/distributed/querynode/client/client_test.go b/internal/distributed/querynode/client/client_test.go index b9874cc0847ca..e24a8b59a268e 100644 --- a/internal/distributed/querynode/client/client_test.go +++ b/internal/distributed/querynode/client/client_test.go @@ -19,18 +19,14 @@ package grpcquerynodeclient import ( "context" "testing" - "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/distributed/utils" - "github.com/milvus-io/milvus/internal/mocks" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" - testify_mock "github.com/stretchr/testify/mock" - "google.golang.org/grpc" ) func Test_NewClient(t *testing.T) { @@ -165,89 +161,3 @@ func Test_NewClient(t *testing.T) { err = client.Close() assert.NoError(t, err) } - -func Test_InternalTLS(t *testing.T) { - paramtable.Init() - validPath := "../../../../configs/cert1/ca.pem" - ctx := context.Background() - client, err := NewClient(ctx, "test", 1) - assert.NoError(t, err) - assert.NotNil(t, client) - defer client.Close() - - mockQN := mocks.NewMockQueryNodeClient(t) - mockGrpcClient := mocks.NewMockGrpcClient[querypb.QueryNodeClient](t) - - mockGrpcClient.EXPECT().Close().Return(nil) - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockGrpcClient.EXPECT().ReCall(testify_mock.Anything, testify_mock.Anything).RunAndReturn(func(ctx context.Context, f func(querypb.QueryNodeClient) (interface{}, error)) (interface{}, error) { - return f(mockQN) - }) - - t.Run("NoCertPool", func(t *testing.T) { - var ErrNoCertPool = errors.New("no cert pool") - mockGrpcClient.EXPECT().SetInternalTLSCertPool(testify_mock.Anything).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(nil) - - mockQN.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, ErrNoCertPool) - - _, err := client.GetComponentStates(ctx, nil) - assert.Error(t, err) - assert.Equal(t, ErrNoCertPool, err) - }) - - // Sub-test for invalid certificate path - t.Run("InvalidCertPath", func(t *testing.T) { - invalidCAPath := "invalid/path/to/ca.pem" - cp, err := utils.CreateCertPoolforClient(invalidCAPath, "querynode") - assert.NotNil(t, err) - assert.Nil(t, cp) - }) - - // Sub-test for TLS handshake failure - t.Run("TlsHandshakeFailed", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "querynode") - assert.Nil(t, err) - mockQN.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockQN.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, errors.New("TLS handshake failed")) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - _, err = client.GetComponentStates(ctx, nil) - assert.NotNil(t, err) - assert.EqualError(t, err, "TLS handshake failed") - }) - - t.Run("TlsHandshakeSuccess", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "querynode") - assert.Nil(t, err) - mockQN.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockQN.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(&milvuspb.ComponentStates{}, nil) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - componentStates, err := client.GetComponentStates(ctx, nil) - assert.Nil(t, err) - assert.NotNil(t, componentStates) - assert.IsType(t, &milvuspb.ComponentStates{}, componentStates) - }) - - t.Run("ContextDeadlineExceeded", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - time.Sleep(20 * time.Millisecond) - - _, err := client.GetComponentStates(ctx, nil) - assert.ErrorIs(t, err, context.DeadlineExceeded) - }) -} diff --git a/internal/distributed/rootcoord/client/client_test.go b/internal/distributed/rootcoord/client/client_test.go index 43582536be1ff..b70883dcf742f 100644 --- a/internal/distributed/rootcoord/client/client_test.go +++ b/internal/distributed/rootcoord/client/client_test.go @@ -25,18 +25,15 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/distributed/utils" - "github.com/milvus-io/milvus/internal/mocks" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" - testify_mock "github.com/stretchr/testify/mock" - "go.uber.org/zap" - "google.golang.org/grpc" ) func TestMain(m *testing.M) { @@ -489,89 +486,3 @@ func Test_NewClient(t *testing.T) { err = client.Close() assert.NoError(t, err) } - -func Test_InternalTLS(t *testing.T) { - paramtable.Init() - validPath := "../../../../configs/cert1/ca.pem" - ctx := context.Background() - client, err := NewClient(ctx) - assert.NoError(t, err) - assert.NotNil(t, client) - defer client.Close() - - mockRC := mocks.NewMockRootCoordClient(t) - mockGrpcClient := mocks.NewMockGrpcClient[rootcoordpb.RootCoordClient](t) - - mockGrpcClient.EXPECT().Close().Return(nil) - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockGrpcClient.EXPECT().ReCall(testify_mock.Anything, testify_mock.Anything).RunAndReturn(func(ctx context.Context, f func(rootcoordpb.RootCoordClient) (interface{}, error)) (interface{}, error) { - return f(mockRC) - }) - - t.Run("NoCertPool", func(t *testing.T) { - var ErrNoCertPool = errors.New("no cert pool") - mockGrpcClient.EXPECT().SetInternalTLSCertPool(testify_mock.Anything).Return().Once() - client.(*Client).grpcClient = mockGrpcClient - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(nil) - - mockRC.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, ErrNoCertPool) - - _, err := client.GetComponentStates(ctx, nil) - assert.Error(t, err) - assert.Equal(t, ErrNoCertPool, err) - }) - - // Sub-test for invalid certificate path - t.Run("InvalidCertPath", func(t *testing.T) { - invalidCAPath := "invalid/path/to/ca.pem" - cp, err := utils.CreateCertPoolforClient(invalidCAPath, "rootcoord") - assert.NotNil(t, err) - assert.Nil(t, cp) - }) - - // Sub-test for TLS handshake failure - t.Run("TlsHandshakeFailed", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "rootcoord") - assert.Nil(t, err) - mockRC.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockRC.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(nil, errors.New("TLS handshake failed")) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - _, err = client.GetComponentStates(ctx, nil) - assert.NotNil(t, err) - assert.EqualError(t, err, "TLS handshake failed") - }) - - t.Run("TlsHandshakeSuccess", func(t *testing.T) { - cp, err := utils.CreateCertPoolforClient(validPath, "rootcoord") - assert.Nil(t, err) - mockRC.ExpectedCalls = nil - - mockGrpcClient.EXPECT().SetInternalTLSCertPool(cp).Return().Once() - mockGrpcClient.EXPECT().GetNodeID().Return(1) - mockRC.EXPECT().GetComponentStates(testify_mock.Anything, testify_mock.Anything).Return(&milvuspb.ComponentStates{}, nil) - - client.(*Client).grpcClient.GetNodeID() - client.(*Client).grpcClient.SetInternalTLSCertPool(cp) - - componentStates, err := client.GetComponentStates(ctx, nil) - assert.Nil(t, err) - assert.NotNil(t, componentStates) - assert.IsType(t, &milvuspb.ComponentStates{}, componentStates) - }) - - t.Run("ContextDeadlineExceeded", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - time.Sleep(20 * time.Millisecond) - - _, err := client.GetComponentStates(ctx, nil) - assert.ErrorIs(t, err, context.DeadlineExceeded) - }) -} diff --git a/internal/distributed/utils/util.go b/internal/distributed/utils/util.go index d95f34ae9e00e..a61b6ff2b0905 100644 --- a/internal/distributed/utils/util.go +++ b/internal/distributed/utils/util.go @@ -2,10 +2,10 @@ package utils import ( "crypto/x509" - "errors" "os" "time" + "github.com/cockroachdb/errors" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/credentials" diff --git a/pkg/util/paramtable/base_table.go b/pkg/util/paramtable/base_table.go index a7e038fc64fb8..3d3bbd38b4e6c 100644 --- a/pkg/util/paramtable/base_table.go +++ b/pkg/util/paramtable/base_table.go @@ -64,8 +64,9 @@ func globalConfigPrefixs() []string { return []string{"metastore", "localStorage", "etcd", "tikv", "minio", "pulsar", "kafka", "rocksmq", "log", "grpc", "common", "quotaAndLimits", "trace"} } -// support read "milvus.yaml", "default.yaml", "user.yaml" as this order. -// order: milvus.yaml < default.yaml < user.yaml, do not change the order below +// support read "milvus.yaml", "_test.yaml", "default.yaml", "user.yaml" as this order. +// order: milvus.yaml < _test.yaml < default.yaml < user.yaml, do not change the order below. +// Use _test.yaml only for test related purpose. var defaultYaml = []string{"milvus.yaml", "_test.yaml", "default.yaml", "user.yaml"} // BaseTable the basics of paramtable diff --git a/pkg/util/paramtable/grpc_param.go b/pkg/util/paramtable/grpc_param.go index eb53070d19ea2..53e394c07ccb8 100644 --- a/pkg/util/paramtable/grpc_param.go +++ b/pkg/util/paramtable/grpc_param.go @@ -547,7 +547,7 @@ func (p *InternalTLSConfig) Init(base *BaseTable) { p.InternalTLSEnabled = ParamItem{ Key: "common.security.internaltlsEnabled", Version: "2.0.0", - DefaultValue: "0", + DefaultValue: "false", Export: true, } p.InternalTLSEnabled.Init(base.mgr) diff --git a/pkg/util/paramtable/grpc_param_test.go b/pkg/util/paramtable/grpc_param_test.go index 5e598d5fd16c6..702c67d79c7a1 100644 --- a/pkg/util/paramtable/grpc_param_test.go +++ b/pkg/util/paramtable/grpc_param_test.go @@ -182,15 +182,15 @@ func TestGrpcClientParams(t *testing.T) { func TestInternalTLSParams(t *testing.T) { base := ComponentParam{} base.Init(NewBaseTable(SkipRemote(true))) - var internalTlsCfg InternalTLSConfig - internalTlsCfg.Init(base.baseTable) + var internalTLSCfg InternalTLSConfig + internalTLSCfg.Init(base.baseTable) - base.Save("common.security.internalTlsEnabled", "True") + base.Save("common.security.internalTlsEnabled", "true") base.Save("internaltls.serverPemPath", "/pem") base.Save("internaltls.serverKeyPath", "/key") base.Save("internaltls.caPemPath", "/ca") - assert.Equal(t, internalTlsCfg.InternalTLSEnabled.GetAsBool(), "True") - assert.Equal(t, internalTlsCfg.InternalTLSServerPemPath.GetValue(), "/pem") - assert.Equal(t, internalTlsCfg.InternalTLSServerKeyPath.GetValue(), "/key") - assert.Equal(t, internalTlsCfg.InternalTLSCaPemPath.GetValue(), "/ca") + assert.Equal(t, internalTLSCfg.InternalTLSEnabled.GetAsBool(), "true") + assert.Equal(t, internalTLSCfg.InternalTLSServerPemPath.GetValue(), "/pem") + assert.Equal(t, internalTLSCfg.InternalTLSServerKeyPath.GetValue(), "/key") + assert.Equal(t, internalTLSCfg.InternalTLSCaPemPath.GetValue(), "/ca") } diff --git a/tests/integration/internaltls/internaltls_test.go b/tests/integration/internaltls/internaltls_test.go index 15399f73e2981..9967051f02ad6 100644 --- a/tests/integration/internaltls/internaltls_test.go +++ b/tests/integration/internaltls/internaltls_test.go @@ -81,18 +81,16 @@ internaltls: caPemPath: ../../../configs/cert/ca.pem ` -// Path to the config file const configFilePath = "../../../configs/_test.yaml" // CreateConfigFile creates the YAML configuration file for tests func CreateConfigFile() { - // Create config directosry if it doesn't exist - // Write config content to user.yaml file - err := os.WriteFile(configFilePath, []byte(configContent), 0644) + // Write config content to _test.yaml file + err := os.WriteFile(configFilePath, []byte(configContent), 0o600) if err != nil { log.Error("Failed to create config file", zap.Error(err)) } - log.Info("config file created") + log.Info(fmt.Sprintf("Config file created: %s", configFilePath)) } func (s *InternaltlsTestSuit) SetupSuite() { @@ -344,7 +342,7 @@ func (s *InternaltlsTestSuit) TearDownSuite() { defer func() { err := os.Remove(configFilePath) if err != nil { - log.Error("failed to delete config file:", zap.Error(err)) + log.Error("Failed to delete config file:", zap.Error(err)) return } log.Info(fmt.Sprintf("Config file deleted: %s", configFilePath)) @@ -353,6 +351,5 @@ func (s *InternaltlsTestSuit) TearDownSuite() { } func TestInternalTLS(t *testing.T) { - log.Info("About to run...") suite.Run(t, new(InternaltlsTestSuit)) }