From bb66bf1b06f7153d4600ba09f7ff063042fb313f Mon Sep 17 00:00:00 2001 From: John Weldon Date: Tue, 2 Jul 2024 11:43:43 -0700 Subject: [PATCH] Add tlsfirst option to pass through to NATS connections (#208) --- README.md | 1 + cmd/root.go | 5 +++++ surveyor/surveyor.go | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/README.md b/README.md index 0125d8b..c36a6cf 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ Flags: --tlscacert string Client certificate CA on NATS connections. --tlscert string Client certificate file for NATS connections. --tlskey string Client private key for NATS connections. + --tlsfirst bool Whether to use TLS First connections. --user string NATS user name or token -v, --version version for nats-surveyor ``` diff --git a/cmd/root.go b/cmd/root.go index ceb761b..6712c1f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -196,6 +196,10 @@ func init() { rootCmd.Flags().String("tlscacert", "", "Client certificate CA on NATS connections.") _ = viper.BindPFlag("tlscacert", rootCmd.Flags().Lookup("tlscacert")) + // tlsfirst + rootCmd.Flags().Bool("tlsfirst", false, "Whether to use TLS First connections.") + _ = viper.BindPFlag("tlsfirst", rootCmd.Flags().Lookup("tlsfirst")) + // http-tlscert rootCmd.Flags().String("http-tlscert", "", "Server certificate file (Enables HTTPS).") _ = viper.BindPFlag("http-tlscert", rootCmd.Flags().Lookup("http-tlscert")) @@ -255,6 +259,7 @@ func getSurveyorOpts() *surveyor.Options { opts.CertFile = viper.GetString("tlscert") opts.KeyFile = viper.GetString("tlskey") opts.CaFile = viper.GetString("tlscacert") + opts.TLSFirst = viper.GetBool("tlsfirst") opts.HTTPCertFile = viper.GetString("http-tlscert") opts.HTTPKeyFile = viper.GetString("http-tlskey") opts.HTTPCaFile = viper.GetString("http-tlscacert") diff --git a/surveyor/surveyor.go b/surveyor/surveyor.go index b1d5c7f..277521c 100644 --- a/surveyor/surveyor.go +++ b/surveyor/surveyor.go @@ -69,6 +69,7 @@ type Options struct { CertFile string KeyFile string CaFile string + TLSFirst bool HTTPCertFile string HTTPKeyFile string HTTPCaFile string @@ -184,6 +185,9 @@ func newSurveyorConnPool(opts *Options, reconnectCtr *prometheus.CounterVec) *na }), nats.MaxReconnects(10240), ) + if opts.TLSFirst { + natsOpts = append(natsOpts, nats.TLSHandshakeFirst()) + } return newNatsConnPool(opts.Logger, natsDefaults, natsOpts) }