Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Access] Unify subscription id with client message id #6847

111 changes: 63 additions & 48 deletions engine/access/rest/websockets/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
"time"

"golang.org/x/time/rate"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -129,7 +129,7 @@ type Controller struct {
// issues such as sending on a closed channel while maintaining proper cleanup.
multiplexedStream chan interface{}

dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider]
dataProviders *concurrentmap.Map[SubscriptionID, dp.DataProvider]
dataProviderFactory dp.DataProviderFactory
dataProvidersGroup *sync.WaitGroup
limiter *rate.Limiter
Expand All @@ -146,7 +146,7 @@ func NewWebSocketController(
config: config,
conn: conn,
multiplexedStream: make(chan interface{}),
dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](),
dataProviders: concurrentmap.New[SubscriptionID, dp.DataProvider](),
dataProviderFactory: dataProviderFactory,
dataProvidersGroup: &sync.WaitGroup{},
limiter: rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1),
Expand Down Expand Up @@ -246,7 +246,7 @@ func (c *Controller) keepalive(ctx context.Context) error {
// If no messages are sent within InactivityTimeout and no active data providers exist,
// the connection will be closed.
func (c *Controller) writeMessages(ctx context.Context) error {
inactivityTicker := time.NewTicker(c.config.InactivityTimeout / 10)
inactivityTicker := time.NewTicker(c.inactivityTickerPeriod())
defer inactivityTicker.Stop()

lastMessageSentAt := time.Now()
Expand Down Expand Up @@ -301,6 +301,10 @@ func (c *Controller) writeMessages(ctx context.Context) error {
}
}

func (c *Controller) inactivityTickerPeriod() time.Duration {
return c.config.InactivityTimeout / 10
}

// readMessages continuously reads messages from a client WebSocket connection,
// validates each message, and processes it based on the message type.
func (c *Controller) readMessages(ctx context.Context) error {
Expand All @@ -314,7 +318,8 @@ func (c *Controller) readMessages(ctx context.Context) error {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidMessage, "error reading message", "", "", ""))
wrapErrorMessage(http.StatusBadRequest, "error reading message", "", ""),
)
continue
}

Expand All @@ -323,7 +328,8 @@ func (c *Controller) readMessages(ctx context.Context) error {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidMessage, "error parsing message", "", "", ""))
wrapErrorMessage(http.StatusBadRequest, "error parsing message", "", ""),
)
continue
}
}
Expand Down Expand Up @@ -366,24 +372,34 @@ func (c *Controller) handleMessage(ctx context.Context, message json.RawMessage)
}

func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) {
subscriptionID, err := c.parseOrCreateSubscriptionID(msg.SubscriptionID)
if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id",
models.SubscribeAction, msg.SubscriptionID),
)
return
}

// register new provider
provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream)
provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID.String(), msg.Topic, msg.Arguments, c.multiplexedStream)
if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidArgument, "error creating data provider", msg.ClientMessageID, models.SubscribeAction, ""),
wrapErrorMessage(http.StatusBadRequest, "error creating data provider",
models.SubscribeAction, subscriptionID.String()),
)
return
}
c.dataProviders.Add(provider.ID(), provider)
c.dataProviders.Add(subscriptionID, provider)

// write OK response to client
responseOk := models.SubscribeMessageResponse{
BaseMessageResponse: models.BaseMessageResponse{
ClientMessageID: msg.ClientMessageID,
Success: true,
SubscriptionID: provider.ID().String(),
SubscriptionID: subscriptionID.String(),
},
}
c.writeResponse(ctx, responseOk)
Expand All @@ -396,72 +412,63 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(SubscriptionError, "subscription finished with error", "", "", ""),
wrapErrorMessage(http.StatusInternalServerError, "internal error",
models.SubscribeAction, subscriptionID.String()),
)
}

c.dataProvidersGroup.Done()
c.dataProviders.Remove(provider.ID())
c.dataProviders.Remove(subscriptionID)
}()
}

func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) {
id, err := uuid.Parse(msg.SubscriptionID)
subscriptionID, err := ParseClientSubscriptionID(msg.SubscriptionID)
if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(InvalidArgument, "error parsing subscription ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID),
wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id",
models.UnsubscribeAction, msg.SubscriptionID),
)
return
}

provider, ok := c.dataProviders.Get(id)
provider, ok := c.dataProviders.Get(subscriptionID)
if !ok {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(NotFound, "subscription not found", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID),
wrapErrorMessage(http.StatusNotFound, "subscription not found",
models.UnsubscribeAction, subscriptionID.String()),
)
return
}

provider.Close()
c.dataProviders.Remove(id)
c.dataProviders.Remove(subscriptionID)

responseOk := models.UnsubscribeMessageResponse{
BaseMessageResponse: models.BaseMessageResponse{
ClientMessageID: msg.ClientMessageID,
Success: true,
SubscriptionID: msg.SubscriptionID,
SubscriptionID: subscriptionID.String(),
},
}
c.writeResponse(ctx, responseOk)
}

func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) {
func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListSubscriptionsMessageRequest) {
var subs []*models.SubscriptionEntry
err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error {
_ = c.dataProviders.ForEach(func(id SubscriptionID, provider dp.DataProvider) error {
subs = append(subs, &models.SubscriptionEntry{
ID: id.String(),
Topic: provider.Topic(),
SubscriptionID: id.String(),
Topic: provider.Topic(),
})
return nil
})

if err != nil {
c.writeErrorResponse(
ctx,
err,
wrapErrorMessage(NotFound, "error listing subscriptions", msg.ClientMessageID, models.ListSubscriptionsAction, ""),
)
return
}

responseOk := models.ListSubscriptionsMessageResponse{
Success: true,
ClientMessageID: msg.ClientMessageID,
Subscriptions: subs,
Subscriptions: subs,
Action: models.ListSubscriptionsAction,
}
c.writeResponse(ctx, responseOk)
}
Expand All @@ -472,13 +479,10 @@ func (c *Controller) shutdownConnection() {
c.logger.Debug().Err(err).Msg("error closing connection")
}

err = c.dataProviders.ForEach(func(_ uuid.UUID, provider dp.DataProvider) error {
_ = c.dataProviders.ForEach(func(_ SubscriptionID, provider dp.DataProvider) error {
provider.Close()
return nil
})
illia-malachyn marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
c.logger.Debug().Err(err).Msg("error closing data provider")
}

c.dataProviders.Clear()
c.dataProvidersGroup.Wait()
Expand All @@ -498,15 +502,26 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) {
}
}

func wrapErrorMessage(code Code, message string, msgId string, action string, subscriptionID string) models.BaseMessageResponse {
func wrapErrorMessage(code int, message string, action string, subscriptionID string) models.BaseMessageResponse {
return models.BaseMessageResponse{
ClientMessageID: msgId,
Success: false,
SubscriptionID: subscriptionID,
SubscriptionID: subscriptionID,
Error: models.ErrorMessage{
Code: int(code),
Code: code,
Message: message,
Action: action,
},
Action: action,
}
}

func (c *Controller) parseOrCreateSubscriptionID(id string) (SubscriptionID, error) {
newId, err := NewSubscriptionID(id)
if err != nil {
return SubscriptionID{}, err
}

if c.dataProviders.Has(newId) {
return SubscriptionID{}, fmt.Errorf("subscription ID is already in use: %s", newId)
}

return newId, nil
}
Loading
Loading