From 253d7bf6bd08cafcb0e484736f84a2387e41eb62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Bires?= Date: Fri, 11 Mar 2022 15:18:28 -0300 Subject: [PATCH] message client test --- client.go | 83 ++++++++++++++++++++++++++++++++++++++--------- client_channel.go | 28 ++++++++-------- client_test.go | 51 +++++++++++++++++++++++++++++ handler.go | 2 +- server.go | 6 ++-- 5 files changed, 137 insertions(+), 33 deletions(-) create mode 100644 client_test.go diff --git a/client.go b/client.go index c7508dd..3bc8182 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package lime import ( "context" "crypto/tls" + "errors" "fmt" "github.com/google/uuid" "log" @@ -12,7 +13,6 @@ import ( "os" "reflect" "runtime" - "sync" "time" ) @@ -20,7 +20,10 @@ type Client struct { config *ClientConfig channel *ClientChannel mux *EnvelopeMux - mu sync.Mutex + lock chan struct{} + cancel context.CancelFunc // cancel stops the channel listener goroutine + done chan bool // done is used by the listener goroutine to signal its end + } func NewClient(config *ClientConfig, mux *EnvelopeMux) *Client { @@ -30,17 +33,23 @@ func NewClient(config *ClientConfig, mux *EnvelopeMux) *Client { if mux == nil || reflect.ValueOf(mux).IsNil() { panic("nil mux") } - return &Client{config: config, mux: mux} + c := &Client{config: config, mux: mux, lock: make(chan struct{}, 1)} + c.startListener() + return c } func (c *Client) Close() error { - c.mu.Lock() - defer c.mu.Unlock() + c.lock <- struct{}{} + defer func() { + <-c.lock + }() if c.channel == nil { return nil } + c.stopListener() + if c.channel.Established() { // Try to close the session gracefully ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second*5) @@ -50,7 +59,7 @@ func (c *Client) Close() error { return err } - err := c.channel.transport.Close() + err := c.channel.Close() c.channel = nil return err } @@ -96,8 +105,17 @@ func (c *Client) getOrBuildChannel(ctx context.Context) (*ClientChannel, error) return c.channel, nil } - c.mu.Lock() - defer c.mu.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case c.lock <- struct{}{}: + break + } + + defer func() { + <-c.lock + }() + if c.channelOK() { return c.channel, nil } @@ -112,6 +130,7 @@ func (c *Client) getOrBuildChannel(ctx context.Context) (*ClientChannel, error) // calling close just to release resources. _ = c.channel.Close() } + c.channel = channel return channel, nil } @@ -122,7 +141,41 @@ func (c *Client) getOrBuildChannel(ctx context.Context) (*ClientChannel, error) count++ } - return nil, fmt.Errorf("getOrBuildChannel: %w", ctx.Err()) + return nil, fmt.Errorf("client: getOrBuildChannel: %w", ctx.Err()) +} + +func (c *Client) startListener() { + ctx, cancel := context.WithCancel(context.Background()) + c.cancel = cancel + c.done = make(chan bool) + + go func() { + defer close(c.done) + + for ctx.Err() == nil { + channel, err := c.getOrBuildChannel(ctx) + if err != nil { + log.Printf("client: listen: %v", err) + continue + } + + if err := c.mux.ListenClient(ctx, channel); err != nil { + if errors.Is(err, context.Canceled) { + // stopListener has been called + continue + } + log.Printf("client: listen: %v", err) + } + } + }() +} + +func (c *Client) stopListener() { + if c.cancel != nil { + c.cancel() + <-c.done + c.cancel = nil + } } func (c *Client) buildChannel(ctx context.Context) (*ClientChannel, error) { @@ -183,16 +236,16 @@ func NewClientConfig() *ClientConfig { Port: 55321, }, nil) }, - CompSelector: func(compressions []SessionCompression) SessionCompression { - return compressions[0] + CompSelector: func(options []SessionCompression) SessionCompression { + return options[0] }, - EncryptSelector: func(encryptions []SessionEncryption) SessionEncryption { - if contains(encryptions, SessionEncryptionTLS) { + EncryptSelector: func(options []SessionEncryption) SessionEncryption { + if contains(options, SessionEncryptionTLS) { return SessionEncryptionTLS } - return encryptions[0] + return options[0] }, - Authenticator: func(schemes []AuthenticationScheme, authentication Authentication) Authentication { + Authenticator: func(schemes []AuthenticationScheme, _ Authentication) Authentication { if contains(schemes, AuthenticationSchemeGuest) { return &GuestAuthentication{} } diff --git a/client_channel.go b/client_channel.go index 547d31f..870450f 100644 --- a/client_channel.go +++ b/client_channel.go @@ -19,7 +19,7 @@ func NewClientChannel(t Transport, bufferSize int) *ClientChannel { func (c *ClientChannel) receiveSessionFromServer(ctx context.Context) (*Session, error) { ses, err := c.receiveSession(ctx) if err != nil { - return nil, fmt.Errorf("receive session failed: %w", err) + return nil, fmt.Errorf("receive session: %w", err) } if ses.State == SessionStateEstablished { @@ -130,12 +130,12 @@ func (c *ClientChannel) sendFinishingSession(ctx context.Context) error { } // CompressionSelector defines a function for selecting the compression for a session. -type CompressionSelector func([]SessionCompression) SessionCompression +type CompressionSelector func(options []SessionCompression) SessionCompression // EncryptionSelector defines a function for selecting the encryption for a session. -type EncryptionSelector func([]SessionEncryption) SessionEncryption +type EncryptionSelector func(options []SessionEncryption) SessionEncryption -type Authenticator func([]AuthenticationScheme, Authentication) Authentication +type Authenticator func(schemes []AuthenticationScheme, roundTrip Authentication) Authentication // EstablishSession performs the client session negotiation and authentication handshake. func (c *ClientChannel) EstablishSession( @@ -156,17 +156,17 @@ func (c *ClientChannel) EstablishSession( ses, err := c.startNewSession(ctx) if err != nil { - return nil, fmt.Errorf("error establishing the session: %w", err) + return nil, fmt.Errorf("establish session: %w", err) } // Session negotiation if ses.State == SessionStateNegotiating { if compSelector == nil { - panic("the compression selector should not be nil") + panic("nil compression selector") } if encryptSelector == nil { - panic("the encryption selector should not be nil") + panic("nil encrypt selector") } // Select options @@ -175,20 +175,20 @@ func (c *ClientChannel) EstablishSession( compSelector(ses.CompressionOptions), encryptSelector(ses.EncryptionOptions)) if err != nil { - return nil, fmt.Errorf("error establishing the session: %w", err) + return nil, fmt.Errorf("establish session: %w", err) } if ses.State == SessionStateNegotiating { if ses.Compression != "" && ses.Compression != c.transport.Compression() { err = c.transport.SetCompression(ctx, ses.Compression) if err != nil { - return nil, fmt.Errorf("error setting the session compression: %w", err) + return nil, fmt.Errorf("establish session: set compression: %w", err) } } if ses.Encryption != "" && ses.Encryption != c.transport.Encryption() { err = c.transport.SetEncryption(ctx, ses.Encryption) if err != nil { - return nil, fmt.Errorf("error setting the session encryption: %w", err) + return nil, fmt.Errorf("establish session: set encryption: %w", err) } } } @@ -196,7 +196,7 @@ func (c *ClientChannel) EstablishSession( // Await for authentication options ses, err = c.receiveSessionFromServer(ctx) if err != nil { - return nil, fmt.Errorf("error establishing the session: %w", err) + return nil, fmt.Errorf("establish session: %w", err) } } @@ -211,7 +211,7 @@ func (c *ClientChannel) EstablishSession( instance, ) if err != nil { - return nil, fmt.Errorf("error establishing the session: %w", err) + return nil, fmt.Errorf("establish session: %w", err) } roundTrip = ses.Authentication } @@ -222,12 +222,12 @@ func (c *ClientChannel) EstablishSession( // FinishSession performs the session finishing handshake. func (c *ClientChannel) FinishSession(ctx context.Context) (*Session, error) { if err := c.sendFinishingSession(ctx); err != nil { - return nil, fmt.Errorf("error sending the finishing session: %w", err) + return nil, fmt.Errorf("finish session: %w", err) } ses, err := c.receiveSessionFromServer(ctx) if err != nil { - return nil, fmt.Errorf("error receiving the finished the session: %w", err) + return nil, fmt.Errorf("finish session: %w", err) } return ses, nil diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..5d86a8f --- /dev/null +++ b/client_test.go @@ -0,0 +1,51 @@ +package lime + +import ( + "context" + "errors" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "log" + "testing" + "time" +) + +func TestClient_NewClient_Message(t *testing.T) { + // Arrange + defer goleak.VerifyNone(t) + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + addr1 := InProcessAddr("localhost") + msgChan := make(chan *Message, 1) + server := NewServerBuilder(). + ListenInProcess(addr1). + MessagesHandlerFunc( + func(ctx context.Context, msg *Message, s Sender) error { + msgChan <- msg + return nil + }). + Build() + defer silentClose(server) + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, ErrServerClosed) { + log.Println(err) + } + }() + config := NewClientConfig() + config.NewTransport = func(ctx context.Context) (Transport, error) { + return DialInProcess(addr1, 1) + } + mux := &EnvelopeMux{} + client := NewClient(config, mux) + msg := createMessage() + + // Act + err := client.SendMessage(ctx, msg) + + // Assert + assert.NoError(t, err) + rcvMsg := <-msgChan + assert.Equal(t, msg, rcvMsg) + err = client.Close() + assert.NoError(t, err) +} diff --git a/handler.go b/handler.go index b139466..1fe0b43 100644 --- a/handler.go +++ b/handler.go @@ -62,7 +62,7 @@ func (m *EnvelopeMux) listen(ctx context.Context, c *channel) error { } } } - return nil + return ctx.Err() } func (m *EnvelopeMux) handleMessage(ctx context.Context, msg *Message, s Sender) error { diff --git a/server.go b/server.go index 5954bd4..7edc8e8 100644 --- a/server.go +++ b/server.go @@ -303,7 +303,7 @@ func (b *ServerBuilder) ListenWebsocket(addr net.TCPAddr, config *WebsocketConfi func (b *ServerBuilder) ListenInProcess(addr InProcessAddr) *ServerBuilder { listener := NewInProcessTransportListener(addr) - b.listeners = append(b.listeners, NewBoundListener(listener, &addr)) + b.listeners = append(b.listeners, NewBoundListener(listener, addr)) return b } @@ -408,8 +408,8 @@ func NewBoundListener(listener TransportListener, addr net.Addr) BoundListener { if listener == nil || reflect.ValueOf(listener).IsNil() { panic("nil Listener") } - if addr == nil || reflect.ValueOf(addr).IsNil() { - panic("nil Addr") + if addr == nil || reflect.ValueOf(addr).IsZero() { + panic("zero addr value") } return BoundListener{ Listener: listener,