Skip to content

Commit

Permalink
Refactor nats connection options
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelattwood committed Jun 13, 2024
1 parent 91df8f4 commit b92022a
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 222 deletions.
189 changes: 32 additions & 157 deletions surveyor/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,97 +2,14 @@ package surveyor

import (
"crypto/sha256"
"crypto/tls"
"encoding/json"
"fmt"
"os"
"sync"

"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)

type natsContext struct {
Name string `json:"name"`
URL string `json:"url"`
JWT string `json:"jwt"`
Seed string `json:"seed"`
Credentials string `json:"credential"`
Nkey string `json:"nkey"`
Token string `json:"token"`
Username string `json:"username"`
Password string `json:"password"`
TLSCA string `json:"tls_ca"`
TLSCert string `json:"tls_cert"`
TLSKey string `json:"tls_key"`

// only passed programmatically
NatsOpts []nats.Option `json:"-"`
}

func (c *natsContext) copy() *natsContext {
if c == nil {
return nil
}
cp := *c
return &cp
}

func (c *natsContext) hash() (string, error) {
b, err := json.Marshal(c)
if err != nil {
return "", fmt.Errorf("error marshaling context to json: %v", err)
}
if c.Nkey != "" {
fb, err := os.ReadFile(c.Nkey)
if err != nil {
return "", fmt.Errorf("error opening nkey file %s: %v", c.Nkey, err)
}
b = append(b, fb...)
}
if c.Credentials != "" {
fb, err := os.ReadFile(c.Credentials)
if err != nil {
return "", fmt.Errorf("error opening creds file %s: %v", c.Credentials, err)
}
b = append(b, fb...)
}
if c.TLSCA != "" {
fb, err := os.ReadFile(c.TLSCA)
if err != nil {
return "", fmt.Errorf("error opening ca file %s: %v", c.TLSCA, err)
}
b = append(b, fb...)
}
if c.TLSCert != "" {
fb, err := os.ReadFile(c.TLSCert)
if err != nil {
return "", fmt.Errorf("error opening cert file %s: %v", c.TLSCert, err)
}
b = append(b, fb...)
}
if c.TLSKey != "" {
fb, err := os.ReadFile(c.TLSKey)
if err != nil {
return "", fmt.Errorf("error opening key file %s: %v", c.TLSKey, err)
}
b = append(b, fb...)
}
hash := sha256.New()
hash.Write(b)
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

type natsContextDefaults struct {
Name string
URL string
TLSCA string
TLSCert string
TLSKey string
TLSConfig *tls.Config
}

type pooledNatsConn struct {
nc *nats.Conn
cp *natsConnPool
Expand Down Expand Up @@ -121,15 +38,15 @@ type natsConnPool struct {
cache map[string]*pooledNatsConn
logger *logrus.Logger
group *singleflight.Group
natsDefaults *natsContextDefaults
natsDefaults []nats.Option
natsOpts []nats.Option
}

func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, natsOpts []nats.Option) *natsConnPool {
func newNatsConnPool(logger *logrus.Logger, natsDefaults []nats.Option, natsOpts []nats.Option) *natsConnPool {
return &natsConnPool{
cache: map[string]*pooledNatsConn{},
group: &singleflight.Group{},
logger: logger,
group: &singleflight.Group{},
natsDefaults: natsDefaults,
natsOpts: natsOpts,
}
Expand All @@ -138,39 +55,15 @@ func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, n
const getPooledConnMaxTries = 10

// Get returns a *pooledNatsConn
func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) {
if cfg == nil {
return nil, fmt.Errorf("nats context must not be nil")
}

// copy cfg
cfg = cfg.copy()

// set defaults
if cfg.Name == "" {
cfg.Name = cp.natsDefaults.Name
}
if cfg.URL == "" {
cfg.URL = cp.natsDefaults.URL
}
if cfg.TLSCA == "" {
cfg.TLSCA = cp.natsDefaults.TLSCA
}
if cfg.TLSCert == "" {
cfg.TLSCert = cp.natsDefaults.TLSCert
}
if cfg.TLSKey == "" {
cfg.TLSKey = cp.natsDefaults.TLSKey
func (cp *natsConnPool) Get(opts []nats.Option) (*pooledNatsConn, error) {
if len(opts) == 0 {
return nil, fmt.Errorf("nats options must not be empty ")
}

// get hash
key, err := cfg.hash()
if err != nil {
return nil, err
}
key := cp.hash(opts)

for i := 0; i < getPooledConnMaxTries; i++ {
connection, err := cp.getPooledConn(key, cfg)
connection, err := cp.getPooledConn(key, opts)
if err != nil {
return nil, err
}
Expand All @@ -192,7 +85,7 @@ func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) {
}

// getPooledConn gets or establishes a *pooledNatsConn in a singleflight group, but does not increment its count
func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNatsConn, error) {
func (cp *natsConnPool) getPooledConn(key string, opts []nats.Option) (*pooledNatsConn, error) {
conn, err, _ := cp.group.Do(key, func() (interface{}, error) {
cp.Lock()
pooledConn, ok := cp.cache[key]
Expand All @@ -202,52 +95,16 @@ func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNats
}
cp.Unlock()

opts := append(cp.natsOpts, cfg.NatsOpts...)
opts = append(opts, func(options *nats.Options) error {
if cfg.Name != "" {
options.Name = cfg.Name
}
if cfg.Token != "" {
options.Token = cfg.Token
}
if cfg.Username != "" {
options.User = cfg.Username
}
if cfg.Password != "" {
options.Password = cfg.Password
}
return nil
})

if cfg.JWT != "" && cfg.Seed != "" {
opts = append(opts, nats.UserJWTAndSeed(cfg.JWT, cfg.Seed))
}

if cfg.Nkey != "" {
opt, err := nats.NkeyOptionFromSeed(cfg.Nkey)
if err != nil {
return nil, fmt.Errorf("unable to load nkey: %v", err)
}
opts = append(opts, opt)
connOpts := nats.GetDefaultOptions()
for _, o := range opts {
o(&connOpts)
}

if cfg.Credentials != "" {
opts = append(opts, nats.UserCredentials(cfg.Credentials))
}

if cfg.TLSCA != "" {
opts = append(opts, nats.RootCAs(cfg.TLSCA))
}

if cfg.TLSCert != "" && cfg.TLSKey != "" {
opts = append(opts, nats.ClientCert(cfg.TLSCert, cfg.TLSKey))
}

nc, err := nats.Connect(cfg.URL, opts...)
nc, err := connOpts.Connect()
if err != nil {
return nil, err
}
cp.logger.Infof("%s connected to NATS Deployment: %s", cfg.Name, nc.ConnectedAddr())
cp.logger.Infof("%s connected to NATS Deployment: %s", connOpts.Name, nc.ConnectedAddr())

connection := &pooledNatsConn{
nc: nc,
Expand All @@ -272,3 +129,21 @@ func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNats
}
return connection, nil
}

func (cp *natsConnPool) hash(opts []nats.Option) string {
cloneOpts := nats.GetDefaultOptions()

// Set opts
for _, o := range opts {
o(&cloneOpts)
}

ptrBytes := make([]byte, 0)
for _, f := range opts {
ptrBytes = append(ptrBytes, []byte(fmt.Sprintf("%p", f))...)
}
hash := sha256.New()
hash.Write(ptrBytes)

return fmt.Sprintf("%x", hash.Sum(nil))
}
28 changes: 20 additions & 8 deletions surveyor/conn_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,30 @@ func TestConnPool(t *testing.T) {

s := natsservertest.RunRandClientPortServer()
defer s.Shutdown()
o1 := &natsContext{
Name: "Client 1",

clientOne := func(o *nats.Options) error {
o.Name = "Client 1"
return nil
}

o1 := []nats.Option{
clientOne,
}
o2 := &natsContext{
Name: "Client 1",
o2 := []nats.Option{
clientOne,
}
o3 := &natsContext{
Name: "Client 2",
o3 := []nats.Option{
func(o *nats.Options) error {
o.Name = "Client 2"
return nil
},
}

natsDefaults := &natsContextDefaults{
URL: s.ClientURL(),
natsDefaults := []nats.Option{
func(o *nats.Options) error {
o.Url = s.ClientURL()
return nil
},
}
natsOptions := []nats.Option{
nats.MaxReconnects(10240),
Expand Down
45 changes: 26 additions & 19 deletions surveyor/jetstream_advisories.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,22 +379,29 @@ func newJetStreamAdvisoryListener(config *JSAdvisoryConfig, cp *natsConnPool, lo
}, nil
}

func (o *jsAdvisoryListener) natsContext() *natsContext {
natsCtx := &natsContext{
JWT: o.config.JWT,
Seed: o.config.Seed,
Credentials: o.config.Credentials,
Nkey: o.config.Nkey,
Token: o.config.Token,
Username: o.config.Username,
Password: o.config.Password,
TLSCA: o.config.TLSCA,
TLSCert: o.config.TLSCert,
TLSKey: o.config.TLSKey,
NatsOpts: o.config.NatsOpts,
}

return natsCtx
func (o *jsAdvisoryListener) connOpts() []nats.Option {
opts := append(make([]nats.Option, 0), o.cp.natsDefaults...)

if o.config.Username != "" && o.config.Password != "" {
opts = append(opts, nats.UserInfo(o.config.Username, o.config.Password))
}
if o.config.Token != "" {
opts = append(opts, nats.Token(o.config.Token))
}
if o.config.Credentials != "" {
opts = append(opts, nats.UserCredentials(o.config.Credentials))
}
if o.config.JWT != "" && o.config.Seed != "" {
opts = append(opts, nats.UserJWTAndSeed(o.config.JWT, o.config.Seed))
}
if o.config.TLSCert != "" && o.config.TLSKey != "" {
opts = append(opts, nats.ClientCert(o.config.TLSCert, o.config.TLSKey))
}
if o.config.TLSCA != "" {
opts = append(opts, nats.RootCAs(o.config.TLSCA))
}

return append(opts, o.config.NatsOpts...)
}

// Start starts listening for JetStream advisories
Expand All @@ -406,7 +413,7 @@ func (o *jsAdvisoryListener) Start() error {
return nil
}

pc, err := o.cp.Get(o.natsContext())
pc, err := o.cp.Get(o.connOpts())
if err != nil {
return fmt.Errorf("nats connection failed for id: %s, account name: %s, error: %v", o.config.ID, o.config.AccountName, err)
}
Expand All @@ -417,13 +424,13 @@ func (o *jsAdvisoryListener) Start() error {
advisorySubject = o.config.ExternalAccountConfig.AdvisorySubject
}

subAdvisory, err := pc.nc.Subscribe(metricsSubject, o.advisoryHandler)
subAdvisory, err := pc.nc.Subscribe(advisorySubject, o.advisoryHandler)
if err != nil {
pc.ReturnToPool()
return fmt.Errorf("could not subscribe to JetStream advisory for id: %s, account name: %s, topic: %s, error: %v", o.config.ID, o.config.AccountName, api.JSAdvisoryPrefix, err)
}

subMetric, err := pc.nc.Subscribe(advisorySubject, o.advisoryHandler)
subMetric, err := pc.nc.Subscribe(metricsSubject, o.advisoryHandler)
if err != nil {
_ = subAdvisory.Unsubscribe()
pc.ReturnToPool()
Expand Down
Loading

0 comments on commit b92022a

Please sign in to comment.