Skip to content

Commit

Permalink
add config field to set internal tls sni
Browse files Browse the repository at this point in the history
Signed-off-by: haorenfsa <[email protected]>
  • Loading branch information
haorenfsa committed Dec 2, 2024
1 parent cfe5613 commit d4c4caa
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 10 deletions.
3 changes: 2 additions & 1 deletion configs/milvus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -784,11 +784,12 @@ tls:
serverKeyPath: configs/cert/server.key
caPemPath: configs/cert/ca.pem

# Configure the node-tls enable.
# Configure the internal tls
internaltls:
serverPemPath: configs/cert/server.pem
serverKeyPath: configs/cert/server.key
caPemPath: configs/cert/ca.pem
sni: "localhost" # The server name indication (SNI) for internal TLS, should be the same as the name provided by the certificates ref: https://en.wikipedia.org/wiki/Server_Name_Indication

common:
defaultPartitionName: _default # Name of the default partition when a collection is created
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/datacoord/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func NewClient(ctx context.Context) (types.DataCoordClient, error) {
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/datanode/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, serverID int64) (types.DataNode
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/indexnode/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool)
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/proxy/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/querycoord/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func NewClient(ctx context.Context) (types.QueryCoordClient, error) {
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/querynode/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/rootcoord/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func NewClient(ctx context.Context) (types.RootCoordClient, error) {
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
21 changes: 16 additions & 5 deletions internal/util/grpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,12 @@ type ClientBase[T interface {
newGrpcClient func(cc *grpc.ClientConn) T

// grpcClient T
grpcClient *clientConnWrapper[T]
encryption bool
cpInternalTLS *x509.CertPool
addr atomic.String
grpcClient *clientConnWrapper[T]
encryption bool
cpInternalTLS *x509.CertPool
addr atomic.String
internalTLSServerName string

// conn *grpc.ClientConn
grpcClientMtx sync.RWMutex
role string
Expand Down Expand Up @@ -194,6 +196,10 @@ func (c *ClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) {
c.cpInternalTLS = cp
}

func (c *ClientBase[T]) SetInternalTLSServerName(cp string) {
c.internalTLSServerName = cp
}

// SetNewGrpcClientFunc sets newGrpcClient of client
func (c *ClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) {
c.newGrpcClient = f
Expand Down Expand Up @@ -269,7 +275,12 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
dialContext,
addr,
// #nosec G402
grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: c.cpInternalTLS})),
grpc.WithTransportCredentials(credentials.NewTLS(
&tls.Config{
RootCAs: c.cpInternalTLS,
ServerName: c.internalTLSServerName,
},
)),
grpc.WithBlock(),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize),
Expand Down
16 changes: 12 additions & 4 deletions pkg/util/paramtable/grpc_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,35 +541,43 @@ type InternalTLSConfig struct {
InternalTLSServerPemPath ParamItem `refreshable:"false"`
InternalTLSServerKeyPath ParamItem `refreshable:"false"`
InternalTLSCaPemPath ParamItem `refreshable:"false"`
InternalTLSSNI ParamItem `refreshable:"false"`
}

func (p *InternalTLSConfig) Init(base *BaseTable) {
p.InternalTLSEnabled = ParamItem{
Key: "common.security.internaltlsEnabled",
Version: "2.0.0",
Version: "2.5.0",
DefaultValue: "false",
Export: true,
}
p.InternalTLSEnabled.Init(base.mgr)

p.InternalTLSServerPemPath = ParamItem{
Key: "internaltls.serverPemPath",
Version: "2.0.0",
Version: "2.5.0",
Export: true,
}
p.InternalTLSServerPemPath.Init(base.mgr)

p.InternalTLSServerKeyPath = ParamItem{
Key: "internaltls.serverKeyPath",
Version: "2.0.0",
Version: "2.5.0",
Export: true,
}
p.InternalTLSServerKeyPath.Init(base.mgr)

p.InternalTLSCaPemPath = ParamItem{
Key: "internaltls.caPemPath",
Version: "2.0.0",
Version: "2.5.0",
Export: true,
}
p.InternalTLSCaPemPath.Init(base.mgr)

p.InternalTLSSNI = ParamItem{
Key: "internaltls.sni",
Version: "2.5.0",
Export: true,
}
p.InternalTLSSNI.Init(base.mgr)
}
2 changes: 2 additions & 0 deletions pkg/util/paramtable/grpc_param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,10 @@ func TestInternalTLSParams(t *testing.T) {
base.Save("internaltls.serverPemPath", "/pem")
base.Save("internaltls.serverKeyPath", "/key")
base.Save("internaltls.caPemPath", "/ca")
base.Save("internaltls.caPemPath", "localhost")
assert.Equal(t, internalTLSCfg.InternalTLSEnabled.GetAsBool(), true)
assert.Equal(t, internalTLSCfg.InternalTLSServerPemPath.GetValue(), "/pem")
assert.Equal(t, internalTLSCfg.InternalTLSServerKeyPath.GetValue(), "/key")
assert.Equal(t, internalTLSCfg.InternalTLSCaPemPath.GetValue(), "/ca")
assert.Equal(t, internalTLSCfg.InternalTLSSNI.GetValue(), "localhost")
}

0 comments on commit d4c4caa

Please sign in to comment.