Skip to content

Commit

Permalink
message client test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrebires committed Mar 11, 2022
1 parent bb48da9 commit 253d7bf
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 33 deletions.
83 changes: 68 additions & 15 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package lime
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/google/uuid"
"log"
Expand All @@ -12,15 +13,17 @@ import (
"os"
"reflect"
"runtime"
"sync"
"time"
)

type Client struct {
config *ClientConfig
channel *ClientChannel
mux *EnvelopeMux
mu sync.Mutex
lock chan struct{}
cancel context.CancelFunc // cancel stops the channel listener goroutine
done chan bool // done is used by the listener goroutine to signal its end

}

func NewClient(config *ClientConfig, mux *EnvelopeMux) *Client {
Expand All @@ -30,17 +33,23 @@ func NewClient(config *ClientConfig, mux *EnvelopeMux) *Client {
if mux == nil || reflect.ValueOf(mux).IsNil() {
panic("nil mux")
}
return &Client{config: config, mux: mux}
c := &Client{config: config, mux: mux, lock: make(chan struct{}, 1)}
c.startListener()
return c
}

func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.lock <- struct{}{}
defer func() {
<-c.lock
}()

if c.channel == nil {
return nil
}

c.stopListener()

if c.channel.Established() {
// Try to close the session gracefully
ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second*5)
Expand All @@ -50,7 +59,7 @@ func (c *Client) Close() error {
return err
}

err := c.channel.transport.Close()
err := c.channel.Close()
c.channel = nil
return err
}
Expand Down Expand Up @@ -96,8 +105,17 @@ func (c *Client) getOrBuildChannel(ctx context.Context) (*ClientChannel, error)
return c.channel, nil
}

c.mu.Lock()
defer c.mu.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case c.lock <- struct{}{}:
break
}

defer func() {
<-c.lock
}()

if c.channelOK() {
return c.channel, nil
}
Expand All @@ -112,6 +130,7 @@ func (c *Client) getOrBuildChannel(ctx context.Context) (*ClientChannel, error)
// calling close just to release resources.
_ = c.channel.Close()
}

c.channel = channel
return channel, nil
}
Expand All @@ -122,7 +141,41 @@ func (c *Client) getOrBuildChannel(ctx context.Context) (*ClientChannel, error)
count++
}

return nil, fmt.Errorf("getOrBuildChannel: %w", ctx.Err())
return nil, fmt.Errorf("client: getOrBuildChannel: %w", ctx.Err())
}

func (c *Client) startListener() {
ctx, cancel := context.WithCancel(context.Background())
c.cancel = cancel
c.done = make(chan bool)

go func() {
defer close(c.done)

for ctx.Err() == nil {
channel, err := c.getOrBuildChannel(ctx)
if err != nil {
log.Printf("client: listen: %v", err)
continue
}

if err := c.mux.ListenClient(ctx, channel); err != nil {
if errors.Is(err, context.Canceled) {
// stopListener has been called
continue
}
log.Printf("client: listen: %v", err)
}
}
}()
}

func (c *Client) stopListener() {
if c.cancel != nil {
c.cancel()
<-c.done
c.cancel = nil
}
}

func (c *Client) buildChannel(ctx context.Context) (*ClientChannel, error) {
Expand Down Expand Up @@ -183,16 +236,16 @@ func NewClientConfig() *ClientConfig {
Port: 55321,
}, nil)
},
CompSelector: func(compressions []SessionCompression) SessionCompression {
return compressions[0]
CompSelector: func(options []SessionCompression) SessionCompression {
return options[0]
},
EncryptSelector: func(encryptions []SessionEncryption) SessionEncryption {
if contains(encryptions, SessionEncryptionTLS) {
EncryptSelector: func(options []SessionEncryption) SessionEncryption {
if contains(options, SessionEncryptionTLS) {
return SessionEncryptionTLS
}
return encryptions[0]
return options[0]
},
Authenticator: func(schemes []AuthenticationScheme, authentication Authentication) Authentication {
Authenticator: func(schemes []AuthenticationScheme, _ Authentication) Authentication {
if contains(schemes, AuthenticationSchemeGuest) {
return &GuestAuthentication{}
}
Expand Down
28 changes: 14 additions & 14 deletions client_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func NewClientChannel(t Transport, bufferSize int) *ClientChannel {
func (c *ClientChannel) receiveSessionFromServer(ctx context.Context) (*Session, error) {
ses, err := c.receiveSession(ctx)
if err != nil {
return nil, fmt.Errorf("receive session failed: %w", err)
return nil, fmt.Errorf("receive session: %w", err)
}

if ses.State == SessionStateEstablished {
Expand Down Expand Up @@ -130,12 +130,12 @@ func (c *ClientChannel) sendFinishingSession(ctx context.Context) error {
}

// CompressionSelector defines a function for selecting the compression for a session.
type CompressionSelector func([]SessionCompression) SessionCompression
type CompressionSelector func(options []SessionCompression) SessionCompression

// EncryptionSelector defines a function for selecting the encryption for a session.
type EncryptionSelector func([]SessionEncryption) SessionEncryption
type EncryptionSelector func(options []SessionEncryption) SessionEncryption

type Authenticator func([]AuthenticationScheme, Authentication) Authentication
type Authenticator func(schemes []AuthenticationScheme, roundTrip Authentication) Authentication

// EstablishSession performs the client session negotiation and authentication handshake.
func (c *ClientChannel) EstablishSession(
Expand All @@ -156,17 +156,17 @@ func (c *ClientChannel) EstablishSession(

ses, err := c.startNewSession(ctx)
if err != nil {
return nil, fmt.Errorf("error establishing the session: %w", err)
return nil, fmt.Errorf("establish session: %w", err)
}

// Session negotiation
if ses.State == SessionStateNegotiating {
if compSelector == nil {
panic("the compression selector should not be nil")
panic("nil compression selector")
}

if encryptSelector == nil {
panic("the encryption selector should not be nil")
panic("nil encrypt selector")
}

// Select options
Expand All @@ -175,28 +175,28 @@ func (c *ClientChannel) EstablishSession(
compSelector(ses.CompressionOptions),
encryptSelector(ses.EncryptionOptions))
if err != nil {
return nil, fmt.Errorf("error establishing the session: %w", err)
return nil, fmt.Errorf("establish session: %w", err)
}

if ses.State == SessionStateNegotiating {
if ses.Compression != "" && ses.Compression != c.transport.Compression() {
err = c.transport.SetCompression(ctx, ses.Compression)
if err != nil {
return nil, fmt.Errorf("error setting the session compression: %w", err)
return nil, fmt.Errorf("establish session: set compression: %w", err)
}
}
if ses.Encryption != "" && ses.Encryption != c.transport.Encryption() {
err = c.transport.SetEncryption(ctx, ses.Encryption)
if err != nil {
return nil, fmt.Errorf("error setting the session encryption: %w", err)
return nil, fmt.Errorf("establish session: set encryption: %w", err)
}
}
}

// Await for authentication options
ses, err = c.receiveSessionFromServer(ctx)
if err != nil {
return nil, fmt.Errorf("error establishing the session: %w", err)
return nil, fmt.Errorf("establish session: %w", err)
}
}

Expand All @@ -211,7 +211,7 @@ func (c *ClientChannel) EstablishSession(
instance,
)
if err != nil {
return nil, fmt.Errorf("error establishing the session: %w", err)
return nil, fmt.Errorf("establish session: %w", err)
}
roundTrip = ses.Authentication
}
Expand All @@ -222,12 +222,12 @@ func (c *ClientChannel) EstablishSession(
// FinishSession performs the session finishing handshake.
func (c *ClientChannel) FinishSession(ctx context.Context) (*Session, error) {
if err := c.sendFinishingSession(ctx); err != nil {
return nil, fmt.Errorf("error sending the finishing session: %w", err)
return nil, fmt.Errorf("finish session: %w", err)
}

ses, err := c.receiveSessionFromServer(ctx)
if err != nil {
return nil, fmt.Errorf("error receiving the finished the session: %w", err)
return nil, fmt.Errorf("finish session: %w", err)
}

return ses, nil
Expand Down
51 changes: 51 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package lime

import (
"context"
"errors"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"log"
"testing"
"time"
)

func TestClient_NewClient_Message(t *testing.T) {
// Arrange
defer goleak.VerifyNone(t)
ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
defer cancel()
addr1 := InProcessAddr("localhost")
msgChan := make(chan *Message, 1)
server := NewServerBuilder().
ListenInProcess(addr1).
MessagesHandlerFunc(
func(ctx context.Context, msg *Message, s Sender) error {
msgChan <- msg
return nil
}).
Build()
defer silentClose(server)
go func() {
if err := server.ListenAndServe(); err != nil && !errors.Is(err, ErrServerClosed) {
log.Println(err)
}
}()
config := NewClientConfig()
config.NewTransport = func(ctx context.Context) (Transport, error) {
return DialInProcess(addr1, 1)
}
mux := &EnvelopeMux{}
client := NewClient(config, mux)
msg := createMessage()

// Act
err := client.SendMessage(ctx, msg)

// Assert
assert.NoError(t, err)
rcvMsg := <-msgChan
assert.Equal(t, msg, rcvMsg)
err = client.Close()
assert.NoError(t, err)
}
2 changes: 1 addition & 1 deletion handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (m *EnvelopeMux) listen(ctx context.Context, c *channel) error {
}
}
}
return nil
return ctx.Err()
}

func (m *EnvelopeMux) handleMessage(ctx context.Context, msg *Message, s Sender) error {
Expand Down
6 changes: 3 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func (b *ServerBuilder) ListenWebsocket(addr net.TCPAddr, config *WebsocketConfi

func (b *ServerBuilder) ListenInProcess(addr InProcessAddr) *ServerBuilder {
listener := NewInProcessTransportListener(addr)
b.listeners = append(b.listeners, NewBoundListener(listener, &addr))
b.listeners = append(b.listeners, NewBoundListener(listener, addr))
return b
}

Expand Down Expand Up @@ -408,8 +408,8 @@ func NewBoundListener(listener TransportListener, addr net.Addr) BoundListener {
if listener == nil || reflect.ValueOf(listener).IsNil() {
panic("nil Listener")
}
if addr == nil || reflect.ValueOf(addr).IsNil() {
panic("nil Addr")
if addr == nil || reflect.ValueOf(addr).IsZero() {
panic("zero addr value")
}
return BoundListener{
Listener: listener,
Expand Down

0 comments on commit 253d7bf

Please sign in to comment.