Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace natsContext with standard NATS Options #204

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 32 additions & 157 deletions surveyor/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,97 +2,14 @@ package surveyor

import (
"crypto/sha256"
"crypto/tls"
"encoding/json"
"fmt"
"os"
"sync"

"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)

type natsContext struct {
Name string `json:"name"`
URL string `json:"url"`
JWT string `json:"jwt"`
Seed string `json:"seed"`
Credentials string `json:"credential"`
Nkey string `json:"nkey"`
Token string `json:"token"`
Username string `json:"username"`
Password string `json:"password"`
TLSCA string `json:"tls_ca"`
TLSCert string `json:"tls_cert"`
TLSKey string `json:"tls_key"`

// only passed programmatically
NatsOpts []nats.Option `json:"-"`
}

func (c *natsContext) copy() *natsContext {
if c == nil {
return nil
}
cp := *c
return &cp
}

func (c *natsContext) hash() (string, error) {
b, err := json.Marshal(c)
if err != nil {
return "", fmt.Errorf("error marshaling context to json: %v", err)
}
if c.Nkey != "" {
fb, err := os.ReadFile(c.Nkey)
if err != nil {
return "", fmt.Errorf("error opening nkey file %s: %v", c.Nkey, err)
}
b = append(b, fb...)
}
if c.Credentials != "" {
fb, err := os.ReadFile(c.Credentials)
if err != nil {
return "", fmt.Errorf("error opening creds file %s: %v", c.Credentials, err)
}
b = append(b, fb...)
}
if c.TLSCA != "" {
fb, err := os.ReadFile(c.TLSCA)
if err != nil {
return "", fmt.Errorf("error opening ca file %s: %v", c.TLSCA, err)
}
b = append(b, fb...)
}
if c.TLSCert != "" {
fb, err := os.ReadFile(c.TLSCert)
if err != nil {
return "", fmt.Errorf("error opening cert file %s: %v", c.TLSCert, err)
}
b = append(b, fb...)
}
if c.TLSKey != "" {
fb, err := os.ReadFile(c.TLSKey)
if err != nil {
return "", fmt.Errorf("error opening key file %s: %v", c.TLSKey, err)
}
b = append(b, fb...)
}
hash := sha256.New()
hash.Write(b)
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

type natsContextDefaults struct {
Name string
URL string
TLSCA string
TLSCert string
TLSKey string
TLSConfig *tls.Config
}

type pooledNatsConn struct {
nc *nats.Conn
cp *natsConnPool
Expand Down Expand Up @@ -121,15 +38,15 @@ type natsConnPool struct {
cache map[string]*pooledNatsConn
logger *logrus.Logger
group *singleflight.Group
natsDefaults *natsContextDefaults
natsDefaults []nats.Option
natsOpts []nats.Option
}

func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, natsOpts []nats.Option) *natsConnPool {
func newNatsConnPool(logger *logrus.Logger, natsDefaults []nats.Option, natsOpts []nats.Option) *natsConnPool {
return &natsConnPool{
cache: map[string]*pooledNatsConn{},
group: &singleflight.Group{},
logger: logger,
group: &singleflight.Group{},
natsDefaults: natsDefaults,
natsOpts: natsOpts,
}
Expand All @@ -138,39 +55,15 @@ func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, n
const getPooledConnMaxTries = 10

// Get returns a *pooledNatsConn
func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) {
if cfg == nil {
return nil, fmt.Errorf("nats context must not be nil")
}

// copy cfg
cfg = cfg.copy()

// set defaults
if cfg.Name == "" {
cfg.Name = cp.natsDefaults.Name
}
if cfg.URL == "" {
cfg.URL = cp.natsDefaults.URL
}
if cfg.TLSCA == "" {
cfg.TLSCA = cp.natsDefaults.TLSCA
}
if cfg.TLSCert == "" {
cfg.TLSCert = cp.natsDefaults.TLSCert
}
if cfg.TLSKey == "" {
cfg.TLSKey = cp.natsDefaults.TLSKey
func (cp *natsConnPool) Get(opts []nats.Option) (*pooledNatsConn, error) {
if len(opts) == 0 {
return nil, fmt.Errorf("nats options must not be empty ")
}

// get hash
key, err := cfg.hash()
if err != nil {
return nil, err
}
key := cp.hash(opts)

for i := 0; i < getPooledConnMaxTries; i++ {
connection, err := cp.getPooledConn(key, cfg)
connection, err := cp.getPooledConn(key, opts)
if err != nil {
return nil, err
}
Expand All @@ -192,7 +85,7 @@ func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) {
}

// getPooledConn gets or establishes a *pooledNatsConn in a singleflight group, but does not increment its count
func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNatsConn, error) {
func (cp *natsConnPool) getPooledConn(key string, opts []nats.Option) (*pooledNatsConn, error) {
conn, err, _ := cp.group.Do(key, func() (interface{}, error) {
cp.Lock()
pooledConn, ok := cp.cache[key]
Expand All @@ -202,52 +95,16 @@ func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNats
}
cp.Unlock()

opts := append(cp.natsOpts, cfg.NatsOpts...)
opts = append(opts, func(options *nats.Options) error {
if cfg.Name != "" {
options.Name = cfg.Name
}
if cfg.Token != "" {
options.Token = cfg.Token
}
if cfg.Username != "" {
options.User = cfg.Username
}
if cfg.Password != "" {
options.Password = cfg.Password
}
return nil
})

if cfg.JWT != "" && cfg.Seed != "" {
opts = append(opts, nats.UserJWTAndSeed(cfg.JWT, cfg.Seed))
}

if cfg.Nkey != "" {
opt, err := nats.NkeyOptionFromSeed(cfg.Nkey)
if err != nil {
return nil, fmt.Errorf("unable to load nkey: %v", err)
}
opts = append(opts, opt)
connOpts := nats.GetDefaultOptions()
for _, o := range opts {
o(&connOpts)
}

if cfg.Credentials != "" {
opts = append(opts, nats.UserCredentials(cfg.Credentials))
}

if cfg.TLSCA != "" {
opts = append(opts, nats.RootCAs(cfg.TLSCA))
}

if cfg.TLSCert != "" && cfg.TLSKey != "" {
opts = append(opts, nats.ClientCert(cfg.TLSCert, cfg.TLSKey))
}

nc, err := nats.Connect(cfg.URL, opts...)
nc, err := connOpts.Connect()
if err != nil {
return nil, err
}
cp.logger.Infof("%s connected to NATS Deployment: %s", cfg.Name, nc.ConnectedAddr())
cp.logger.Infof("%s connected to NATS Deployment: %s", connOpts.Name, nc.ConnectedAddr())

connection := &pooledNatsConn{
nc: nc,
Expand All @@ -272,3 +129,21 @@ func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNats
}
return connection, nil
}

func (cp *natsConnPool) hash(opts []nats.Option) string {
cloneOpts := nats.GetDefaultOptions()

// Set opts
for _, o := range opts {
o(&cloneOpts)
}

ptrBytes := make([]byte, 0)
for _, f := range opts {
ptrBytes = append(ptrBytes, []byte(fmt.Sprintf("%p", f))...)
}
hash := sha256.New()
hash.Write(ptrBytes)

return fmt.Sprintf("%x", hash.Sum(nil))
}
38 changes: 25 additions & 13 deletions surveyor/conn_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,30 @@ func TestConnPool(t *testing.T) {

s := natsservertest.RunRandClientPortServer()
defer s.Shutdown()
o1 := &natsContext{
Name: "Client 1",

clientOne := func(o *nats.Options) error {
o.Name = "Client 1"
return nil
}

o1 := []nats.Option{
clientOne,
}
o2 := &natsContext{
Name: "Client 1",
o2 := []nats.Option{
clientOne,
}
o3 := &natsContext{
Name: "Client 2",
o3 := []nats.Option{
func(o *nats.Options) error {
o.Name = "Client 2"
return nil
},
}

natsDefaults := &natsContextDefaults{
URL: s.ClientURL(),
natsDefaults := []nats.Option{
func(o *nats.Options) error {
o.Url = s.ClientURL()
return nil
},
}
natsOptions := []nats.Option{
nats.MaxReconnects(10240),
Expand All @@ -40,15 +52,15 @@ func TestConnPool(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(3)
go func() {
c1, c1e = cp.Get(o1)
c1, c1e = cp.Get(append(natsDefaults, append(natsOptions, o1...)...))
wg.Done()
}()
go func() {
c2, c2e = cp.Get(o2)
c2, c2e = cp.Get(append(natsDefaults, append(natsOptions, o2...)...))
wg.Done()
}()
go func() {
c3, c3e = cp.Get(o3)
c3, c3e = cp.Get(append(natsDefaults, append(natsOptions, o3...)...))
wg.Done()
}()
wg.Wait()
Expand All @@ -69,7 +81,7 @@ func TestConnPool(t *testing.T) {
assert.False(c2.nc.IsClosed())
assert.True(c3.nc.IsClosed())

c4, c4e := cp.Get(o1)
c4, c4e := cp.Get(append(natsDefaults, append(natsOptions, o1...)...))
if assert.NoError(c4e) {
assert.Same(c2, c4)
}
Expand All @@ -81,7 +93,7 @@ func TestConnPool(t *testing.T) {
assert.True(c2.nc.IsClosed())
assert.True(c4.nc.IsClosed())

c5, c5e := cp.Get(o1)
c5, c5e := cp.Get(append(natsDefaults, append(natsOptions, o1...)...))
if assert.NoError(c5e) {
assert.NotSame(c1, c5)
}
Expand Down
Loading
Loading