From e5a91d70c4591fcf6f521047381146a517990c2c Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 3 Jan 2025 14:59:54 +0200 Subject: [PATCH 01/10] Unify subscription id with client message id From now on, we use only 1 id in request/response messages. This id is called `subscription_id`. A client may provide `subscription_id` in `subscribe` request. If client does not provide it, we generate it ourselves. Clients that use browsers or other async environemnts may use `subscription_id` to correlate response messages with the request ones. `subscription_id` is used in all messages related to subscription. I also remove `success` field from response. We include `subscription_id` field in a resposne in case of OK response. In case of error response, we include `error` field. --- engine/access/rest/websockets/controller.go | 76 +++++++---- .../access/rest/websockets/controller_test.go | 125 +++++++----------- .../data_providers/base_provider.go | 9 -- .../data_providers/data_provider.go | 8 +- .../data_providers/events_provider.go | 2 +- .../data_providers/mock/data_provider.go | 25 +--- engine/access/rest/websockets/error_codes.go | 2 +- .../rest/websockets/models/account_models.go | 2 +- .../rest/websockets/models/base_message.go | 17 ++- .../rest/websockets/models/error_message.go | 7 - .../rest/websockets/models/event_models.go | 2 +- .../websockets/models/list_subscriptions.go | 6 +- .../websockets/models/subscribe_message.go | 2 +- .../websockets/models/subscription_entry.go | 5 +- .../websockets/models/unsubscribe_message.go | 1 - 15 files changed, 122 insertions(+), 167 deletions(-) delete mode 100644 engine/access/rest/websockets/models/error_message.go diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 15c187fc650..99531bcc59a 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -306,7 +306,8 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error reading message", "", "", "")) + wrapErrorMessage(InvalidMessage, "error reading message", ""), + ) continue } @@ -315,7 +316,8 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error parsing message", "", "", "")) + wrapErrorMessage(InvalidMessage, "error parsing message", ""), + ) continue } } @@ -358,24 +360,32 @@ func (c *Controller) handleMessage(ctx context.Context, message json.RawMessage) } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { + subscriptionID, err := c.parseOrCreateSubscriptionID(msg.SubscriptionID) + if err != nil { + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidArgument, "error parsing subscription id", msg.SubscriptionID), + ) + return + } + // register new provider provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error creating data provider", msg.ClientMessageID, models.SubscribeAction, ""), + wrapErrorMessage(InvalidArgument, "error creating data provider", subscriptionID.String()), ) return } - c.dataProviders.Add(provider.ID(), provider) + c.dataProviders.Add(subscriptionID, provider) // write OK response to client responseOk := models.SubscribeMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - ClientMessageID: msg.ClientMessageID, - Success: true, - SubscriptionID: provider.ID().String(), + SubscriptionID: subscriptionID.String(), }, } c.writeResponse(ctx, responseOk) @@ -388,44 +398,42 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(SubscriptionError, "subscription finished with error", "", "", ""), + wrapErrorMessage(InternalError, "internal error", subscriptionID.String()), ) } c.dataProvidersGroup.Done() - c.dataProviders.Remove(provider.ID()) + c.dataProviders.Remove(subscriptionID) }() } func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { - id, err := uuid.Parse(msg.SubscriptionID) + subscriptionID, err := uuid.Parse(msg.SubscriptionID) if err != nil { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error parsing subscription ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), + wrapErrorMessage(InvalidArgument, "error parsing subscription id", msg.SubscriptionID), ) return } - provider, ok := c.dataProviders.Get(id) + provider, ok := c.dataProviders.Get(subscriptionID) if !ok { c.writeErrorResponse( ctx, err, - wrapErrorMessage(NotFound, "subscription not found", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), + wrapErrorMessage(NotFound, "subscription not found", subscriptionID.String()), ) return } provider.Close() - c.dataProviders.Remove(id) + c.dataProviders.Remove(subscriptionID) responseOk := models.UnsubscribeMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - ClientMessageID: msg.ClientMessageID, - Success: true, - SubscriptionID: msg.SubscriptionID, + SubscriptionID: subscriptionID.String(), }, } c.writeResponse(ctx, responseOk) @@ -445,15 +453,16 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis c.writeErrorResponse( ctx, err, - wrapErrorMessage(NotFound, "error listing subscriptions", msg.ClientMessageID, models.ListSubscriptionsAction, ""), + wrapErrorMessage(NotFound, "error listing subscriptions", ""), ) return } responseOk := models.ListSubscriptionsMessageResponse{ - Success: true, - ClientMessageID: msg.ClientMessageID, - Subscriptions: subs, + BaseMessageResponse: models.BaseMessageResponse{ + SubscriptionID: msg.SubscriptionID, + }, + Subscriptions: subs, } c.writeResponse(ctx, responseOk) } @@ -490,15 +499,30 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) { } } -func wrapErrorMessage(code Code, message string, msgId string, action string, subscriptionID string) models.BaseMessageResponse { +func wrapErrorMessage(code Code, message string, subscriptionID string) models.BaseMessageResponse { return models.BaseMessageResponse{ - ClientMessageID: msgId, - Success: false, - SubscriptionID: subscriptionID, + SubscriptionID: subscriptionID, Error: models.ErrorMessage{ Code: int(code), Message: message, - Action: action, }, } } + +func (c *Controller) parseOrCreateSubscriptionID(id string) (uuid.UUID, error) { + // if client didn't provide subscription id, we create one for him + if id == "" { + return uuid.New(), nil + } + + newID, err := uuid.Parse(id) + if err != nil { + return uuid.Nil, err + } + + if c.dataProviders.Has(newID) { + return uuid.Nil, fmt.Errorf("subscription id is already in use") + } + + return newID, nil +} diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 1a52f79b516..6d430ffddc0 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -55,10 +55,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -71,8 +68,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.SubscribeAction, + SubscriptionID: uuid.New().String(), + Action: models.SubscribeAction, }, Topic: dp.BlocksTopic, Arguments: nil, @@ -98,9 +95,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) - require.True(t, response.Success) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) - require.Equal(t, id.String(), response.SubscriptionID) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) return websocket.ErrCloseSent }) @@ -148,7 +143,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) require.Equal(t, int(InvalidMessage), response.Error.Code) return websocket.ErrCloseSent @@ -174,7 +168,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}) - s.expectSubscribeRequest(t, conn) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) conn. On("WriteJSON", mock.Anything). @@ -183,7 +178,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) require.Equal(t, int(InvalidArgument), response.Error.Code) @@ -204,7 +198,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) - dataProvider.On("ID").Return(uuid.New()) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -219,8 +212,9 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) conn. On("WriteJSON", mock.Anything). @@ -229,9 +223,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, int(SubscriptionError), response.Error.Code) + require.Equal(t, int(InternalError), response.Error.Code) return websocket.ErrCloseSent }) @@ -258,10 +251,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -272,15 +262,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + SubscriptionID: subscriptionID, + Action: models.UnsubscribeAction, }, - SubscriptionID: id.String(), } requestJson, err := json.Marshal(request) require.NoError(t, err) @@ -302,9 +292,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { response, ok := msg.(models.UnsubscribeMessageResponse) require.True(t, ok) - require.True(t, response.Success) require.Empty(t, response.Error) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) require.Equal(t, request.SubscriptionID, response.SubscriptionID) return websocket.ErrCloseSent @@ -331,10 +320,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -345,15 +331,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + SubscriptionID: "invalid-uuid", + Action: models.UnsubscribeAction, }, - SubscriptionID: "invalid-uuid", } requestJson, err := json.Marshal(request) require.NoError(t, err) @@ -375,9 +361,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) require.Equal(t, int(InvalidArgument), response.Error.Code) return websocket.ErrCloseSent @@ -404,10 +389,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -418,15 +400,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + SubscriptionID: uuid.New().String(), // unknown subscription id + Action: models.UnsubscribeAction, }, - SubscriptionID: uuid.New().String(), } requestJson, err := json.Marshal(request) require.NoError(t, err) @@ -448,10 +430,9 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) - require.NotEmpty(t, response.Error) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.NotEmpty(t, response.Error) require.Equal(t, int(NotFound), response.Error.Code) return websocket.ErrCloseSent @@ -481,9 +462,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { done := make(chan struct{}) - id := uuid.New() topic := dp.BlocksTopic - dataProvider.On("ID").Return(id) dataProvider.On("Topic").Return(topic) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -495,13 +474,14 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.ListSubscriptionsMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.ListSubscriptionsAction, + SubscriptionID: "", + Action: models.ListSubscriptionsAction, }, } requestJson, err := json.Marshal(request) @@ -524,11 +504,10 @@ func (s *WsControllerSuite) TestListSubscriptions() { response, ok := msg.(models.ListSubscriptionsMessageResponse) require.True(t, ok) - require.True(t, response.Success) require.Empty(t, response.Error) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Empty(t, response.SubscriptionID) require.Equal(t, 1, len(response.Subscriptions)) - require.Equal(t, id.String(), response.Subscriptions[0].ID) + require.Equal(t, subscriptionID, response.Subscriptions[0].ID) require.Equal(t, topic, response.Subscriptions[0].Topic) return websocket.ErrCloseSent @@ -558,8 +537,6 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Return(dataProvider, nil). Once() - id := uuid.New() - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -574,8 +551,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -613,8 +591,6 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Return(dataProvider, nil). Once() - id := uuid.New() - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -631,8 +607,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -752,8 +729,6 @@ func (s *WsControllerSuite) TestControllerShutdown() { Return(dataProvider, nil). Once() - id := uuid.New() - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -766,8 +741,9 @@ func (s *WsControllerSuite) TestControllerShutdown() { Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := uuid.New().String() + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) conn. On("WriteJSON", mock.Anything). @@ -927,11 +903,11 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da } // expectSubscribeRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock.WebsocketConnection) string { +func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock.WebsocketConnection, subscriptionID string) { request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.SubscribeAction, + SubscriptionID: subscriptionID, + Action: models.SubscribeAction, }, Topic: dp.BlocksTopic, } @@ -948,19 +924,16 @@ func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock. }). Return(nil). Once() - - return request.ClientMessageID } // expectSubscribeResponse mocks the subscription response sent to the client. -func (s *WsControllerSuite) expectSubscribeResponse(t *testing.T, conn *connmock.WebsocketConnection, msgId string) { +func (s *WsControllerSuite) expectSubscribeResponse(t *testing.T, conn *connmock.WebsocketConnection, subscriptionID string) { conn. On("WriteJSON", mock.Anything). Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(t, ok) - require.Equal(t, msgId, response.ClientMessageID) - require.Equal(t, true, response.Success) + require.Equal(t, subscriptionID, response.SubscriptionID) }). Return(nil). Once() diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go index 0ee040cd4ac..9f9202deeff 100644 --- a/engine/access/rest/websockets/data_providers/base_provider.go +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -3,14 +3,11 @@ package data_providers import ( "context" - "github.com/google/uuid" - "github.com/onflow/flow-go/engine/access/subscription" ) // baseDataProvider holds common objects for the provider type baseDataProvider struct { - id uuid.UUID topic string cancel context.CancelFunc send chan<- interface{} @@ -25,7 +22,6 @@ func newBaseDataProvider( subscription subscription.Subscription, ) *baseDataProvider { return &baseDataProvider{ - id: uuid.New(), topic: topic, cancel: cancel, send: send, @@ -33,11 +29,6 @@ func newBaseDataProvider( } } -// ID returns the unique identifier of the data provider. -func (b *baseDataProvider) ID() uuid.UUID { - return b.id -} - // Topic returns the topic associated with the data provider. func (b *baseDataProvider) Topic() string { return b.topic diff --git a/engine/access/rest/websockets/data_providers/data_provider.go b/engine/access/rest/websockets/data_providers/data_provider.go index ab48ebeb9f9..321d7ba1391 100644 --- a/engine/access/rest/websockets/data_providers/data_provider.go +++ b/engine/access/rest/websockets/data_providers/data_provider.go @@ -1,14 +1,8 @@ package data_providers -import ( - "github.com/google/uuid" -) - // The DataProvider is the interface abstracts of the actual data provider used by the WebSocketCollector. -// It provides methods for retrieving the provider's unique ID, topic, and a methods to close and run the provider. +// It provides methods for retrieving the provider's unique SubscriptionID, topic, and a methods to close and run the provider. type DataProvider interface { - // ID returns the unique identifier of the data provider. - ID() uuid.UUID // Topic returns the topic associated with the data provider. Topic() string // Close terminates the data provider. diff --git a/engine/access/rest/websockets/data_providers/events_provider.go b/engine/access/rest/websockets/data_providers/events_provider.go index 318e8081d2c..bf2b867bcb3 100644 --- a/engine/access/rest/websockets/data_providers/events_provider.go +++ b/engine/access/rest/websockets/data_providers/events_provider.go @@ -103,7 +103,7 @@ func (p *EventsDataProvider) handleResponse() func(eventsResponse *backend.Event } p.send <- &models.EventResponse{ - BlockId: eventsResponse.BlockID.String(), + BlockID: eventsResponse.BlockID.String(), BlockHeight: strconv.FormatUint(eventsResponse.Height, 10), BlockTimestamp: eventsResponse.BlockTimestamp, Events: eventsResponse.Events, diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider.go b/engine/access/rest/websockets/data_providers/mock/data_provider.go index 48debb23ae3..c2bedf29bac 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider.go @@ -2,10 +2,7 @@ package mock -import ( - uuid "github.com/google/uuid" - mock "github.com/stretchr/testify/mock" -) +import mock "github.com/stretchr/testify/mock" // DataProvider is an autogenerated mock type for the DataProvider type type DataProvider struct { @@ -17,26 +14,6 @@ func (_m *DataProvider) Close() { _m.Called() } -// ID provides a mock function with given fields: -func (_m *DataProvider) ID() uuid.UUID { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for ID") - } - - var r0 uuid.UUID - if rf, ok := ret.Get(0).(func() uuid.UUID); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(uuid.UUID) - } - } - - return r0 -} - // Run provides a mock function with given fields: func (_m *DataProvider) Run() error { ret := _m.Called() diff --git a/engine/access/rest/websockets/error_codes.go b/engine/access/rest/websockets/error_codes.go index fd206bed0b3..446c6b5f6f8 100644 --- a/engine/access/rest/websockets/error_codes.go +++ b/engine/access/rest/websockets/error_codes.go @@ -6,5 +6,5 @@ const ( InvalidMessage Code = iota InvalidArgument NotFound - SubscriptionError + InternalError ) diff --git a/engine/access/rest/websockets/models/account_models.go b/engine/access/rest/websockets/models/account_models.go index fdb6826b4f1..0243d4c927d 100644 --- a/engine/access/rest/websockets/models/account_models.go +++ b/engine/access/rest/websockets/models/account_models.go @@ -4,7 +4,7 @@ import "github.com/onflow/flow-go/model/flow" // AccountStatusesResponse is the response message for 'events' topic. type AccountStatusesResponse struct { - BlockID string `json:"blockID"` + BlockID string `json:"block_id"` Height string `json:"height"` AccountEvents map[string]flow.EventsList `json:"account_events"` MessageIndex uint64 `json:"message_index"` diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index cdcd72eb1ed..3b7e5e50745 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -8,14 +8,19 @@ const ( // BaseMessageRequest represents a base structure for incoming messages. type BaseMessageRequest struct { - Action string `json:"action"` // subscribe, unsubscribe or list_subscriptions - ClientMessageID string `json:"message_id"` // ClientMessageID is a uuid generated by client to identify request/response uniquely + // SubscriptionID is UUID generated by either client or server to uniquely identify subscription. + // It is empty for 'list_subscription' action + SubscriptionID string `json:"subscription_id,omitempty"` + Action string `json:"action"` // Action is an action to perform (e.g. 'subscribe' to some data) } // BaseMessageResponse represents a base structure for outgoing messages. type BaseMessageResponse struct { - SubscriptionID string `json:"subscription_id"` - ClientMessageID string `json:"message_id,omitempty"` // ClientMessageID may be empty in case we send msg by ourselves (e.g. error occurred) - Success bool `json:"success"` - Error ErrorMessage `json:"error,omitempty"` + SubscriptionID string `json:"subscription_id,omitempty"` // SubscriptionID might be empty in case of error response + Error ErrorMessage `json:"error,omitempty"` // Error might be empty in case of OK response +} + +type ErrorMessage struct { + Code int `json:"code"` // Code is an error code that categorizes an error + Message string `json:"message"` } diff --git a/engine/access/rest/websockets/models/error_message.go b/engine/access/rest/websockets/models/error_message.go deleted file mode 100644 index d5c0670926f..00000000000 --- a/engine/access/rest/websockets/models/error_message.go +++ /dev/null @@ -1,7 +0,0 @@ -package models - -type ErrorMessage struct { - Code int `json:"code"` - Message string `json:"message"` - Action string `json:"action,omitempty"` -} diff --git a/engine/access/rest/websockets/models/event_models.go b/engine/access/rest/websockets/models/event_models.go index 0659cbc6937..eb956a6c0e9 100644 --- a/engine/access/rest/websockets/models/event_models.go +++ b/engine/access/rest/websockets/models/event_models.go @@ -8,7 +8,7 @@ import ( // EventResponse is the response message for 'events' topic. type EventResponse struct { - BlockId string `json:"block_id"` + BlockID string `json:"block_id"` BlockHeight string `json:"block_height"` BlockTimestamp time.Time `json:"block_timestamp"` Events []flow.Event `json:"events"` diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go index 4893a34b09d..ba4fcc9cb4b 100644 --- a/engine/access/rest/websockets/models/list_subscriptions.go +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -8,8 +8,6 @@ type ListSubscriptionsMessageRequest struct { // ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. // It contains a list of active subscriptions for the current WebSocket connection. type ListSubscriptionsMessageResponse struct { - ClientMessageID string `json:"message_id"` - Success bool `json:"success"` - Error ErrorMessage `json:"error,omitempty"` - Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` + BaseMessageResponse + Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` // Subscriptions might be empty in case of no active subscriptions } diff --git a/engine/access/rest/websockets/models/subscribe_message.go b/engine/access/rest/websockets/models/subscribe_message.go index 532e4c6a987..b4bd9e871da 100644 --- a/engine/access/rest/websockets/models/subscribe_message.go +++ b/engine/access/rest/websockets/models/subscribe_message.go @@ -6,7 +6,7 @@ type Arguments map[string]interface{} type SubscribeMessageRequest struct { BaseMessageRequest Topic string `json:"topic"` // Topic to subscribe to - Arguments Arguments `json:"arguments"` // Additional arguments for subscription + Arguments Arguments `json:"arguments"` // Arguments are the arguments for the subscribed topic } // SubscribeMessageResponse represents the response to a subscription request. diff --git a/engine/access/rest/websockets/models/subscription_entry.go b/engine/access/rest/websockets/models/subscription_entry.go index d3f2b352bb7..d38b39b24ee 100644 --- a/engine/access/rest/websockets/models/subscription_entry.go +++ b/engine/access/rest/websockets/models/subscription_entry.go @@ -2,6 +2,7 @@ package models // SubscriptionEntry represents an active subscription entry. type SubscriptionEntry struct { - Topic string `json:"topic,omitempty"` // Topic of the subscription - ID string `json:"id,omitempty"` // Unique subscription ID + ID string `json:"id"` // ID is a client generated UUID for subscription + Topic string `json:"topic"` // Topic of the subscription + //TODO: maybe we should add arguments for readability? } diff --git a/engine/access/rest/websockets/models/unsubscribe_message.go b/engine/access/rest/websockets/models/unsubscribe_message.go index 1402189a601..4a8283f2392 100644 --- a/engine/access/rest/websockets/models/unsubscribe_message.go +++ b/engine/access/rest/websockets/models/unsubscribe_message.go @@ -3,7 +3,6 @@ package models // UnsubscribeMessageRequest represents a request to unsubscribe from a topic. type UnsubscribeMessageRequest struct { BaseMessageRequest - SubscriptionID string `json:"id"` } // UnsubscribeMessageResponse represents the response to an unsubscription request. From 5a4f44849366f6e7a8181f77f7c3e8f6265de621 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 7 Jan 2025 17:42:27 +0200 Subject: [PATCH 02/10] Fix comments. Fix race condition in inactivity tracker test --- engine/access/rest/websockets/controller.go | 20 ++++++----- .../access/rest/websockets/controller_test.go | 35 +++++++++---------- .../account_statuses_provider.go | 3 ++ .../data_providers/base_provider.go | 21 ++++++----- .../data_providers/block_digests_provider.go | 3 ++ .../data_providers/block_headers_provider.go | 3 ++ .../data_providers/blocks_provider.go | 3 ++ .../data_providers/blocks_provider_test.go | 3 +- .../data_providers/events_provider.go | 3 ++ .../rest/websockets/data_providers/factory.go | 18 +++++----- .../mock/data_provider_factory.go | 20 ++++++----- ...d_and_get_transaction_statuses_provider.go | 3 ++ .../transaction_statuses_provider.go | 3 ++ .../websockets/data_providers/utittest.go | 3 +- engine/access/rest/websockets/error_codes.go | 7 ++-- .../websockets/models/subscription_entry.go | 6 ++-- .../websockets/models/unsubscribe_message.go | 2 ++ 17 files changed, 96 insertions(+), 60 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 99531bcc59a..8fe04dd8e87 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -242,7 +242,7 @@ func (c *Controller) keepalive(ctx context.Context) error { // If no messages are sent within InactivityTimeout and no active data providers exist, // the connection will be closed. func (c *Controller) writeMessages(ctx context.Context) error { - inactivityTicker := time.NewTicker(c.config.InactivityTimeout / 10) + inactivityTicker := time.NewTicker(c.inactivityTickerPeriod()) defer inactivityTicker.Stop() lastMessageSentAt := time.Now() @@ -293,6 +293,10 @@ func (c *Controller) writeMessages(ctx context.Context) error { } } +func (c *Controller) inactivityTickerPeriod() time.Duration { + return c.config.InactivityTimeout / 10 +} + // readMessages continuously reads messages from a client WebSocket connection, // validates each message, and processes it based on the message type. func (c *Controller) readMessages(ctx context.Context) error { @@ -365,18 +369,18 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error parsing subscription id", msg.SubscriptionID), + wrapErrorMessage(InvalidMessage, "error parsing subscription id", msg.SubscriptionID), ) return } // register new provider - provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) + provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID, msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error creating data provider", subscriptionID.String()), + wrapErrorMessage(InvalidMessage, "error creating data provider", subscriptionID.String()), ) return } @@ -398,7 +402,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(InternalError, "internal error", subscriptionID.String()), + wrapErrorMessage(InternalServerError, "internal error", subscriptionID.String()), ) } @@ -413,7 +417,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error parsing subscription id", msg.SubscriptionID), + wrapErrorMessage(InvalidMessage, "error parsing subscription id", msg.SubscriptionID), ) return } @@ -443,8 +447,8 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis var subs []*models.SubscriptionEntry err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error { subs = append(subs, &models.SubscriptionEntry{ - ID: id.String(), - Topic: provider.Topic(), + SubscriptionID: id.String(), + Topic: provider.Topic(), }) return nil }) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 6d430ffddc0..ace74b8c8e1 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -51,7 +51,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -163,7 +163,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil, fmt.Errorf("error creating data provider")). Once() @@ -179,7 +179,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) require.NotEmpty(t, response.Error) - require.Equal(t, int(InvalidArgument), response.Error.Code) + require.Equal(t, int(InvalidMessage), response.Error.Code) return websocket.ErrCloseSent }) @@ -207,7 +207,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -224,7 +224,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) require.NotEmpty(t, response.Error) - require.Equal(t, int(InternalError), response.Error.Code) + require.Equal(t, int(InternalServerError), response.Error.Code) return websocket.ErrCloseSent }) @@ -247,7 +247,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -316,7 +316,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -363,7 +363,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.True(t, ok) require.NotEmpty(t, response.Error) require.Equal(t, request.SubscriptionID, response.SubscriptionID) - require.Equal(t, int(InvalidArgument), response.Error.Code) + require.Equal(t, int(InvalidMessage), response.Error.Code) return websocket.ErrCloseSent }). @@ -385,7 +385,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -456,7 +456,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -507,7 +507,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { require.Empty(t, response.Error) require.Empty(t, response.SubscriptionID) require.Equal(t, 1, len(response.Subscriptions)) - require.Equal(t, subscriptionID, response.Subscriptions[0].ID) + require.Equal(t, subscriptionID, response.Subscriptions[0].SubscriptionID) require.Equal(t, topic, response.Subscriptions[0].Topic) return websocket.ErrCloseSent @@ -533,7 +533,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -587,7 +587,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -725,7 +725,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() @@ -803,15 +803,14 @@ func (s *WsControllerSuite) TestControllerShutdown() { conn. On("ReadJSON", mock.Anything). Return(func(interface{}) error { - // waiting more than InactivityTimeout to make sure that read message routine busy and do not return - // an error before than inactivity tracker initiate shut down - <-time.After(wsConfig.InactivityTimeout) + // make sure the reader routine sleeps for more time than InactivityTimeout + inactivity ticker period. + // meanwhile, the writer routine must shut down the controller. + <-time.After(wsConfig.InactivityTimeout + controller.inactivityTickerPeriod()*2) return websocket.ErrCloseSent }). Once() controller.HandleConnection(context.Background()) - time.Sleep(wsConfig.InactivityTimeout) conn.AssertExpectations(t) }) diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider.go b/engine/access/rest/websockets/data_providers/account_statuses_provider.go index 396dcbc7b9a..5b3c112fd09 100644 --- a/engine/access/rest/websockets/data_providers/account_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider.go @@ -5,6 +5,7 @@ import ( "fmt" "strconv" + "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -42,6 +43,7 @@ func NewAccountStatusesDataProvider( ctx context.Context, logger zerolog.Logger, stateStreamApi state_stream.API, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, send chan<- interface{}, @@ -64,6 +66,7 @@ func NewAccountStatusesDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go index 9f9202deeff..3adc514a924 100644 --- a/engine/access/rest/websockets/data_providers/base_provider.go +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -3,29 +3,34 @@ package data_providers import ( "context" + "github.com/google/uuid" + "github.com/onflow/flow-go/engine/access/subscription" ) // baseDataProvider holds common objects for the provider type baseDataProvider struct { - topic string - cancel context.CancelFunc - send chan<- interface{} - subscription subscription.Subscription + subscriptionID uuid.UUID + topic string + cancel context.CancelFunc + send chan<- interface{} + subscription subscription.Subscription } // newBaseDataProvider creates a new instance of baseDataProvider. func newBaseDataProvider( + subscriptionID uuid.UUID, topic string, cancel context.CancelFunc, send chan<- interface{}, subscription subscription.Subscription, ) *baseDataProvider { return &baseDataProvider{ - topic: topic, - cancel: cancel, - send: send, - subscription: subscription, + subscriptionID: subscriptionID, + topic: topic, + cancel: cancel, + send: send, + subscription: subscription, } } diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider.go b/engine/access/rest/websockets/data_providers/block_digests_provider.go index 80307be6b64..e00f164972e 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -28,6 +29,7 @@ func NewBlockDigestsDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, send chan<- interface{}, @@ -45,6 +47,7 @@ func NewBlockDigestsDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider.go b/engine/access/rest/websockets/data_providers/block_headers_provider.go index 4fddeb499f2..582b9cb0924 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -28,6 +29,7 @@ func NewBlockHeadersDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, send chan<- interface{}, @@ -45,6 +47,7 @@ func NewBlockHeadersDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/blocks_provider.go b/engine/access/rest/websockets/data_providers/blocks_provider.go index 6c09c4a623a..49c5ee43cd2 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -37,6 +38,7 @@ func NewBlocksDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, send chan<- interface{}, @@ -54,6 +56,7 @@ func NewBlocksDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/blocks_provider_test.go b/engine/access/rest/websockets/data_providers/blocks_provider_test.go index 85136ae5819..5c688853d02 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider_test.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider_test.go @@ -6,6 +6,7 @@ import ( "strconv" "testing" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -120,7 +121,7 @@ func (s *BlocksProviderSuite) TestBlocksDataProvider_InvalidArguments() { for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlocksDataProvider(ctx, s.log, s.api, BlocksTopic, test.arguments, send) + provider, err := NewBlocksDataProvider(ctx, s.log, s.api, uuid.New(), BlocksTopic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/events_provider.go b/engine/access/rest/websockets/data_providers/events_provider.go index bf2b867bcb3..2f22ff6cbaf 100644 --- a/engine/access/rest/websockets/data_providers/events_provider.go +++ b/engine/access/rest/websockets/data_providers/events_provider.go @@ -5,6 +5,7 @@ import ( "fmt" "strconv" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/common/parser" @@ -41,6 +42,7 @@ func NewEventsDataProvider( ctx context.Context, logger zerolog.Logger, stateStreamApi state_stream.API, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, send chan<- interface{}, @@ -63,6 +65,7 @@ func NewEventsDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index a29367b8a0b..b8b75c37ab6 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -32,7 +33,7 @@ type DataProviderFactory interface { // and configuration parameters. // // No errors are expected during normal operations. - NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) + NewDataProvider(ctx context.Context, subscriptionID uuid.UUID, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) } var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) @@ -88,25 +89,26 @@ func NewDataProviderFactory( // No errors are expected during normal operations. func (s *DataProviderFactoryImpl) NewDataProvider( ctx context.Context, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, ch chan<- interface{}, ) (DataProvider, error) { switch topic { case BlocksTopic: - return NewBlocksDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + return NewBlocksDataProvider(ctx, s.logger, s.accessApi, subscriptionID, topic, arguments, ch) case BlockHeadersTopic: - return NewBlockHeadersDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + return NewBlockHeadersDataProvider(ctx, s.logger, s.accessApi, subscriptionID, topic, arguments, ch) case BlockDigestsTopic: - return NewBlockDigestsDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + return NewBlockDigestsDataProvider(ctx, s.logger, s.accessApi, subscriptionID, topic, arguments, ch) case EventsTopic: - return NewEventsDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + return NewEventsDataProvider(ctx, s.logger, s.stateStreamApi, subscriptionID, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) case AccountStatusesTopic: - return NewAccountStatusesDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + return NewAccountStatusesDataProvider(ctx, s.logger, s.stateStreamApi, subscriptionID, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) case TransactionStatusesTopic: - return NewTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + return NewTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, subscriptionID, topic, arguments, ch) case SendAndGetTransactionStatusesTopic: - return NewSendAndGetTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + return NewSendAndGetTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, subscriptionID, topic, arguments, ch) default: return nil, fmt.Errorf("unsupported topic \"%s\"", topic) } diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go index af49cb4e687..7c7d4bc58c0 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" models "github.com/onflow/flow-go/engine/access/rest/websockets/models" + + uuid "github.com/google/uuid" ) // DataProviderFactory is an autogenerated mock type for the DataProviderFactory type @@ -16,9 +18,9 @@ type DataProviderFactory struct { mock.Mock } -// NewDataProvider provides a mock function with given fields: ctx, topic, args, ch -func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { - ret := _m.Called(ctx, topic, args, ch) +// NewDataProvider provides a mock function with given fields: ctx, subscriptionID, topic, args, ch +func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, subscriptionID uuid.UUID, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { + ret := _m.Called(ctx, subscriptionID, topic, args, ch) if len(ret) == 0 { panic("no return value specified for NewDataProvider") @@ -26,19 +28,19 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string var r0 data_providers.DataProvider var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { - return rf(ctx, topic, args, ch) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { + return rf(ctx, subscriptionID, topic, args, ch) } - if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { - r0 = rf(ctx, topic, args, ch) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { + r0 = rf(ctx, subscriptionID, topic, args, ch) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(data_providers.DataProvider) } } - if rf, ok := ret.Get(1).(func(context.Context, string, models.Arguments, chan<- interface{}) error); ok { - r1 = rf(ctx, topic, args, ch) + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID, string, models.Arguments, chan<- interface{}) error); ok { + r1 = rf(ctx, subscriptionID, topic, args, ch) } else { r1 = ret.Error(1) } diff --git a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go index f6db73ac4e0..dceaaf5899e 100644 --- a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -37,6 +38,7 @@ func NewSendAndGetTransactionStatusesDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, send chan<- interface{}, @@ -55,6 +57,7 @@ func NewSendAndGetTransactionStatusesDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go index 3b75bde006a..e18a9d3a1f0 100644 --- a/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -40,6 +41,7 @@ func NewTransactionStatusesDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID uuid.UUID, topic string, arguments models.Arguments, send chan<- interface{}, @@ -58,6 +60,7 @@ func NewTransactionStatusesDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/utittest.go b/engine/access/rest/websockets/data_providers/utittest.go index c20c3b2f957..ced0b549e03 100644 --- a/engine/access/rest/websockets/data_providers/utittest.go +++ b/engine/access/rest/websockets/data_providers/utittest.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/onflow/flow-go/engine/access/rest/websockets/models" @@ -63,7 +64,7 @@ func testHappyPath[T any]( test.setupBackend(sub) // Create the data provider instance - provider, err := factory.NewDataProvider(ctx, topic, test.arguments, send) + provider, err := factory.NewDataProvider(ctx, uuid.New(), topic, test.arguments, send) require.NotNil(t, provider) require.NoError(t, err) diff --git a/engine/access/rest/websockets/error_codes.go b/engine/access/rest/websockets/error_codes.go index 446c6b5f6f8..a35deea3bbc 100644 --- a/engine/access/rest/websockets/error_codes.go +++ b/engine/access/rest/websockets/error_codes.go @@ -3,8 +3,7 @@ package websockets type Code int const ( - InvalidMessage Code = iota - InvalidArgument - NotFound - InternalError + InvalidMessage Code = 400 + NotFound Code = 404 + InternalServerError Code = 500 ) diff --git a/engine/access/rest/websockets/models/subscription_entry.go b/engine/access/rest/websockets/models/subscription_entry.go index d38b39b24ee..9a60ab1a0d9 100644 --- a/engine/access/rest/websockets/models/subscription_entry.go +++ b/engine/access/rest/websockets/models/subscription_entry.go @@ -2,7 +2,7 @@ package models // SubscriptionEntry represents an active subscription entry. type SubscriptionEntry struct { - ID string `json:"id"` // ID is a client generated UUID for subscription - Topic string `json:"topic"` // Topic of the subscription - //TODO: maybe we should add arguments for readability? + SubscriptionID string `json:"subscription_id"` // ID is a client generated UUID for subscription + Topic string `json:"topic"` // Topic of the subscription + Arguments Arguments `json:"arguments"` } diff --git a/engine/access/rest/websockets/models/unsubscribe_message.go b/engine/access/rest/websockets/models/unsubscribe_message.go index 4a8283f2392..ca81fbb31a0 100644 --- a/engine/access/rest/websockets/models/unsubscribe_message.go +++ b/engine/access/rest/websockets/models/unsubscribe_message.go @@ -3,6 +3,8 @@ package models // UnsubscribeMessageRequest represents a request to unsubscribe from a topic. type UnsubscribeMessageRequest struct { BaseMessageRequest + //TODO: in this request, subscription_id is mandatory, but we inherit the optional one. + // should we rewrite args to meet requirements? } // UnsubscribeMessageResponse represents the response to an unsubscription request. From ab32fe7b38a76323bf056782f1d23648892eb3a9 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 7 Jan 2025 17:45:49 +0200 Subject: [PATCH 03/10] add ID() method for data provider again --- .../access/rest/websockets/data_providers/base_provider.go | 5 +++++ .../access/rest/websockets/data_providers/data_provider.go | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go index 3adc514a924..8936a883161 100644 --- a/engine/access/rest/websockets/data_providers/base_provider.go +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -34,6 +34,11 @@ func newBaseDataProvider( } } +// ID returns the subscription ID associated with current data provider +func (b *baseDataProvider) ID() uuid.UUID { + return b.subscriptionID +} + // Topic returns the topic associated with the data provider. func (b *baseDataProvider) Topic() string { return b.topic diff --git a/engine/access/rest/websockets/data_providers/data_provider.go b/engine/access/rest/websockets/data_providers/data_provider.go index 321d7ba1391..ed6c11b0f0d 100644 --- a/engine/access/rest/websockets/data_providers/data_provider.go +++ b/engine/access/rest/websockets/data_providers/data_provider.go @@ -1,8 +1,14 @@ package data_providers +import ( + "github.com/google/uuid" +) + // The DataProvider is the interface abstracts of the actual data provider used by the WebSocketCollector. // It provides methods for retrieving the provider's unique SubscriptionID, topic, and a methods to close and run the provider. type DataProvider interface { + // ID returns the unique identifier of the data provider. + ID() uuid.UUID // Topic returns the topic associated with the data provider. Topic() string // Close terminates the data provider. From 9f4e9d6f1be6231ba7695bf15118f47271088efc Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 7 Jan 2025 17:52:06 +0200 Subject: [PATCH 04/10] regenerate mocks --- .../data_providers/mock/data_provider.go | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider.go b/engine/access/rest/websockets/data_providers/mock/data_provider.go index c2bedf29bac..48debb23ae3 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider.go @@ -2,7 +2,10 @@ package mock -import mock "github.com/stretchr/testify/mock" +import ( + uuid "github.com/google/uuid" + mock "github.com/stretchr/testify/mock" +) // DataProvider is an autogenerated mock type for the DataProvider type type DataProvider struct { @@ -14,6 +17,26 @@ func (_m *DataProvider) Close() { _m.Called() } +// ID provides a mock function with given fields: +func (_m *DataProvider) ID() uuid.UUID { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ID") + } + + var r0 uuid.UUID + if rf, ok := ret.Get(0).(func() uuid.UUID); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(uuid.UUID) + } + } + + return r0 +} + // Run provides a mock function with given fields: func (_m *DataProvider) Run() error { ret := _m.Called() From b528a87913101859fcbf48cc05187cb59a61ffbb Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 7 Jan 2025 18:02:15 +0200 Subject: [PATCH 05/10] fix mocks errorS --- .../data_providers/account_statuses_provider_test.go | 3 +++ .../websockets/data_providers/block_digests_provider_test.go | 3 ++- .../websockets/data_providers/block_headers_provider_test.go | 3 ++- .../rest/websockets/data_providers/events_provider_test.go | 3 +++ engine/access/rest/websockets/data_providers/factory_test.go | 5 +++-- .../send_and_get_transaction_statuses_provider_test.go | 2 ++ .../data_providers/transaction_statuses_provider_test.go | 3 +++ 7 files changed, 18 insertions(+), 4 deletions(-) diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go index f482cfd8ff3..1de39ce7ee0 100644 --- a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -179,6 +180,7 @@ func (s *AccountStatusesProviderSuite) TestAccountStatusesDataProvider_InvalidAr ctx, s.log, s.api, + uuid.New(), topic, test.arguments, send, @@ -220,6 +222,7 @@ func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderRe ctx, s.log, s.api, + uuid.New(), topic, arguments, send, diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go index 975716c74af..846c1b3e181 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -43,7 +44,7 @@ func (s *BlockDigestsProviderSuite) TestBlockDigestsDataProvider_InvalidArgument for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlockDigestsDataProvider(ctx, s.log, s.api, topic, test.arguments, send) + provider, err := NewBlockDigestsDataProvider(ctx, s.log, s.api, uuid.New(), topic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go index b929a46d076..8a8cf7dbddb 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -43,7 +44,7 @@ func (s *BlockHeadersProviderSuite) TestBlockHeadersDataProvider_InvalidArgument for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlockHeadersDataProvider(ctx, s.log, s.api, topic, test.arguments, send) + provider, err := NewBlockHeadersDataProvider(ctx, s.log, s.api, uuid.New(), topic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/events_provider_test.go b/engine/access/rest/websockets/data_providers/events_provider_test.go index 6e09212efcf..e1cfbcd62d8 100644 --- a/engine/access/rest/websockets/data_providers/events_provider_test.go +++ b/engine/access/rest/websockets/data_providers/events_provider_test.go @@ -6,6 +6,7 @@ import ( "strconv" "testing" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -204,6 +205,7 @@ func (s *EventsProviderSuite) TestEventsDataProvider_InvalidArguments() { ctx, s.log, s.api, + uuid.New(), topic, test.arguments, send, @@ -245,6 +247,7 @@ func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() ctx, s.log, s.api, + uuid.New(), topic, arguments, send, diff --git a/engine/access/rest/websockets/data_providers/factory_test.go b/engine/access/rest/websockets/data_providers/factory_test.go index ca418e494c7..c6dd8fb930f 100644 --- a/engine/access/rest/websockets/data_providers/factory_test.go +++ b/engine/access/rest/websockets/data_providers/factory_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -161,7 +162,7 @@ func (s *DataProviderFactorySuite) TestSupportedTopics() { s.T().Parallel() test.setupSubscription() - provider, err := s.factory.NewDataProvider(s.ctx, test.topic, test.arguments, s.ch) + provider, err := s.factory.NewDataProvider(s.ctx, uuid.New(), test.topic, test.arguments, s.ch) s.Require().NotNil(provider, "Expected provider for topic %s", test.topic) s.Require().NoError(err, "Expected no error for topic %s", test.topic) s.Require().Equal(test.topic, provider.Topic()) @@ -183,7 +184,7 @@ func (s *DataProviderFactorySuite) TestUnsupportedTopics() { } for _, topic := range unsupportedTopics { - provider, err := s.factory.NewDataProvider(s.ctx, topic, nil, s.ch) + provider, err := s.factory.NewDataProvider(s.ctx, uuid.New(), topic, nil, s.ch) s.Require().Nil(provider, "Expected no provider for unsupported topic %s", topic) s.Require().Error(err, "Expected error for unsupported topic %s", topic) s.Require().EqualError(err, fmt.Sprintf("unsupported topic \"%s\"", topic)) diff --git a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go index 1776d7a873a..0691e3f9584 100644 --- a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -127,6 +128,7 @@ func (s *SendTransactionStatusesProviderSuite) TestSendTransactionStatusesDataPr ctx, s.log, s.api, + uuid.New(), topic, test.arguments, send, diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go index 00f446f7286..4c74275b098 100644 --- a/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go @@ -6,6 +6,7 @@ import ( "strconv" "testing" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -156,6 +157,7 @@ func (s *TransactionStatusesProviderSuite) TestTransactionStatusesDataProvider_I ctx, s.log, s.api, + uuid.New(), topic, test.arguments, send, @@ -243,6 +245,7 @@ func (s *TransactionStatusesProviderSuite) TestMessageIndexTransactionStatusesPr ctx, s.log, s.api, + uuid.New(), topic, arguments, send, From 74a607619b09377ec18c20be4a679234ea954a09 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 14 Jan 2025 15:30:34 +0200 Subject: [PATCH 06/10] fix commets regarding models --- engine/access/rest/websockets/controller.go | 34 ++++++++----------- .../access/rest/websockets/controller_test.go | 13 ++++--- .../{utittest.go => unit_test.go} | 0 engine/access/rest/websockets/error_codes.go | 9 ----- .../rest/websockets/models/base_message.go | 4 +-- .../websockets/models/list_subscriptions.go | 4 +-- .../websockets/models/unsubscribe_message.go | 3 +- 7 files changed, 25 insertions(+), 42 deletions(-) rename engine/access/rest/websockets/data_providers/{utittest.go => unit_test.go} (100%) delete mode 100644 engine/access/rest/websockets/error_codes.go diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 8fe04dd8e87..a63943202d0 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -77,6 +77,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "sync" "time" @@ -310,7 +311,7 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error reading message", ""), + wrapErrorMessage(http.StatusBadRequest, "error reading message", ""), ) continue } @@ -320,7 +321,7 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error parsing message", ""), + wrapErrorMessage(http.StatusBadRequest, "error parsing message", ""), ) continue } @@ -369,7 +370,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error parsing subscription id", msg.SubscriptionID), + wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", msg.SubscriptionID), ) return } @@ -380,7 +381,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error creating data provider", subscriptionID.String()), + wrapErrorMessage(http.StatusBadRequest, "error creating data provider", subscriptionID.String()), ) return } @@ -402,7 +403,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(InternalServerError, "internal error", subscriptionID.String()), + wrapErrorMessage(http.StatusInternalServerError, "internal error", subscriptionID.String()), ) } @@ -417,7 +418,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error parsing subscription id", msg.SubscriptionID), + wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", msg.SubscriptionID), ) return } @@ -427,7 +428,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(NotFound, "subscription not found", subscriptionID.String()), + wrapErrorMessage(http.StatusNotFound, "subscription not found", subscriptionID.String()), ) return } @@ -443,7 +444,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeResponse(ctx, responseOk) } -func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { +func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListSubscriptionsMessageRequest) { var subs []*models.SubscriptionEntry err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error { subs = append(subs, &models.SubscriptionEntry{ @@ -453,19 +454,12 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis return nil }) + // intentionally ignored, this never happens if err != nil { - c.writeErrorResponse( - ctx, - err, - wrapErrorMessage(NotFound, "error listing subscriptions", ""), - ) - return + c.logger.Debug().Err(err).Msg("error listing subscriptions") } responseOk := models.ListSubscriptionsMessageResponse{ - BaseMessageResponse: models.BaseMessageResponse{ - SubscriptionID: msg.SubscriptionID, - }, Subscriptions: subs, } c.writeResponse(ctx, responseOk) @@ -503,18 +497,18 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) { } } -func wrapErrorMessage(code Code, message string, subscriptionID string) models.BaseMessageResponse { +func wrapErrorMessage(code int, message string, subscriptionID string) models.BaseMessageResponse { return models.BaseMessageResponse{ SubscriptionID: subscriptionID, Error: models.ErrorMessage{ - Code: int(code), + Code: code, Message: message, }, } } func (c *Controller) parseOrCreateSubscriptionID(id string) (uuid.UUID, error) { - // if client didn't provide subscription id, we create one for him + // if client didn't provide subscription id, we create one for them if id == "" { return uuid.New(), nil } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index ace74b8c8e1..9774720dbea 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "testing" "time" @@ -144,7 +145,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) require.NotEmpty(t, response.Error) - require.Equal(t, int(InvalidMessage), response.Error.Code) + require.Equal(t, http.StatusBadRequest, response.Error.Code) return websocket.ErrCloseSent }) @@ -179,7 +180,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) require.NotEmpty(t, response.Error) - require.Equal(t, int(InvalidMessage), response.Error.Code) + require.Equal(t, http.StatusBadRequest, response.Error.Code) return websocket.ErrCloseSent }) @@ -224,7 +225,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) require.NotEmpty(t, response.Error) - require.Equal(t, int(InternalServerError), response.Error.Code) + require.Equal(t, http.StatusInternalServerError, response.Error.Code) return websocket.ErrCloseSent }) @@ -363,7 +364,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.True(t, ok) require.NotEmpty(t, response.Error) require.Equal(t, request.SubscriptionID, response.SubscriptionID) - require.Equal(t, int(InvalidMessage), response.Error.Code) + require.Equal(t, http.StatusBadRequest, response.Error.Code) return websocket.ErrCloseSent }). @@ -433,7 +434,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.Equal(t, request.SubscriptionID, response.SubscriptionID) require.NotEmpty(t, response.Error) - require.Equal(t, int(NotFound), response.Error.Code) + require.Equal(t, http.StatusNotFound, response.Error.Code) return websocket.ErrCloseSent }). @@ -504,8 +505,6 @@ func (s *WsControllerSuite) TestListSubscriptions() { response, ok := msg.(models.ListSubscriptionsMessageResponse) require.True(t, ok) - require.Empty(t, response.Error) - require.Empty(t, response.SubscriptionID) require.Equal(t, 1, len(response.Subscriptions)) require.Equal(t, subscriptionID, response.Subscriptions[0].SubscriptionID) require.Equal(t, topic, response.Subscriptions[0].Topic) diff --git a/engine/access/rest/websockets/data_providers/utittest.go b/engine/access/rest/websockets/data_providers/unit_test.go similarity index 100% rename from engine/access/rest/websockets/data_providers/utittest.go rename to engine/access/rest/websockets/data_providers/unit_test.go diff --git a/engine/access/rest/websockets/error_codes.go b/engine/access/rest/websockets/error_codes.go deleted file mode 100644 index a35deea3bbc..00000000000 --- a/engine/access/rest/websockets/error_codes.go +++ /dev/null @@ -1,9 +0,0 @@ -package websockets - -type Code int - -const ( - InvalidMessage Code = 400 - NotFound Code = 404 - InternalServerError Code = 500 -) diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index 3b7e5e50745..dc26a2914b0 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -16,8 +16,8 @@ type BaseMessageRequest struct { // BaseMessageResponse represents a base structure for outgoing messages. type BaseMessageResponse struct { - SubscriptionID string `json:"subscription_id,omitempty"` // SubscriptionID might be empty in case of error response - Error ErrorMessage `json:"error,omitempty"` // Error might be empty in case of OK response + SubscriptionID string `json:"subscription_id"` // SubscriptionID might be empty in case of error response + Error ErrorMessage `json:"error,omitempty"` // Error might be empty in case of OK response } type ErrorMessage struct { diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go index ba4fcc9cb4b..185fc65bbe0 100644 --- a/engine/access/rest/websockets/models/list_subscriptions.go +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -8,6 +8,6 @@ type ListSubscriptionsMessageRequest struct { // ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. // It contains a list of active subscriptions for the current WebSocket connection. type ListSubscriptionsMessageResponse struct { - BaseMessageResponse - Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` // Subscriptions might be empty in case of no active subscriptions + // Subscription list might be empty in case of no active subscriptions + Subscriptions []*SubscriptionEntry `json:"subscriptions"` } diff --git a/engine/access/rest/websockets/models/unsubscribe_message.go b/engine/access/rest/websockets/models/unsubscribe_message.go index ca81fbb31a0..f72e6cb5c7b 100644 --- a/engine/access/rest/websockets/models/unsubscribe_message.go +++ b/engine/access/rest/websockets/models/unsubscribe_message.go @@ -2,9 +2,8 @@ package models // UnsubscribeMessageRequest represents a request to unsubscribe from a topic. type UnsubscribeMessageRequest struct { + // Note: subscription_id is mandatory for this request BaseMessageRequest - //TODO: in this request, subscription_id is mandatory, but we inherit the optional one. - // should we rewrite args to meet requirements? } // UnsubscribeMessageResponse represents the response to an unsubscription request. From 06951a911b5cb5b8f27a3c68db9715754ecf62af Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Wed, 15 Jan 2025 14:27:31 +0200 Subject: [PATCH 07/10] Relax rules for subscription ID Client is not required to provide valid UUID for subscription ID from now on. It can be any string [1-20] characters long. If no subscription ID provided, server creates UUID string for it. --- engine/access/rest/websockets/controller.go | 30 +++++------ .../access/rest/websockets/controller_test.go | 24 ++++----- .../account_statuses_provider.go | 3 +- .../account_statuses_provider_test.go | 5 +- .../data_providers/base_provider.go | 8 ++- .../data_providers/block_digests_provider.go | 3 +- .../block_digests_provider_test.go | 3 +- .../data_providers/block_headers_provider.go | 3 +- .../block_headers_provider_test.go | 3 +- .../data_providers/blocks_provider.go | 3 +- .../data_providers/blocks_provider_test.go | 3 +- .../data_providers/data_provider.go | 6 +-- .../data_providers/events_provider.go | 3 +- .../data_providers/events_provider_test.go | 5 +- .../rest/websockets/data_providers/factory.go | 11 +--- .../websockets/data_providers/factory_test.go | 5 +- .../data_providers/mock/data_provider.go | 15 ++---- .../mock/data_provider_factory.go | 10 ++-- ...d_and_get_transaction_statuses_provider.go | 3 +- ..._get_transaction_statuses_provider_test.go | 3 +- .../transaction_statuses_provider.go | 3 +- .../transaction_statuses_provider_test.go | 5 +- .../websockets/data_providers/unit_test.go | 3 +- .../access/rest/websockets/subscription_id.go | 54 +++++++++++++++++++ 24 files changed, 113 insertions(+), 101 deletions(-) create mode 100644 engine/access/rest/websockets/subscription_id.go diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 0c6010d2b26..efc1f532bce 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -83,7 +83,6 @@ import ( "golang.org/x/time/rate" - "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" @@ -130,7 +129,7 @@ type Controller struct { // issues such as sending on a closed channel while maintaining proper cleanup. multiplexedStream chan interface{} - dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] + dataProviders *concurrentmap.Map[SubscriptionID, dp.DataProvider] dataProviderFactory dp.DataProviderFactory dataProvidersGroup *sync.WaitGroup limiter *rate.Limiter @@ -147,7 +146,7 @@ func NewWebSocketController( config: config, conn: conn, multiplexedStream: make(chan interface{}), - dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), + dataProviders: concurrentmap.New[SubscriptionID, dp.DataProvider](), dataProviderFactory: dataProviderFactory, dataProvidersGroup: &sync.WaitGroup{}, limiter: rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1), @@ -384,7 +383,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe } // register new provider - provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID, msg.Topic, msg.Arguments, c.multiplexedStream) + provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID.String(), msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { c.writeErrorResponse( ctx, @@ -421,7 +420,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe } func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { - subscriptionID, err := uuid.Parse(msg.SubscriptionID) + subscriptionID, err := ParseClientSubscriptionID(msg.SubscriptionID) if err != nil { c.writeErrorResponse( ctx, @@ -454,7 +453,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListSubscriptionsMessageRequest) { var subs []*models.SubscriptionEntry - err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error { + err := c.dataProviders.ForEach(func(id SubscriptionID, provider dp.DataProvider) error { subs = append(subs, &models.SubscriptionEntry{ SubscriptionID: id.String(), Topic: provider.Topic(), @@ -479,7 +478,7 @@ func (c *Controller) shutdownConnection() { c.logger.Debug().Err(err).Msg("error closing connection") } - err = c.dataProviders.ForEach(func(_ uuid.UUID, provider dp.DataProvider) error { + err = c.dataProviders.ForEach(func(_ SubscriptionID, provider dp.DataProvider) error { provider.Close() return nil }) @@ -515,20 +514,15 @@ func wrapErrorMessage(code int, message string, subscriptionID string) models.Ba } } -func (c *Controller) parseOrCreateSubscriptionID(id string) (uuid.UUID, error) { - // if client didn't provide subscription id, we create one for them - if id == "" { - return uuid.New(), nil - } - - newID, err := uuid.Parse(id) +func (c *Controller) parseOrCreateSubscriptionID(id string) (SubscriptionID, error) { + newId, err := NewSubscriptionID(id) if err != nil { - return uuid.Nil, err + return SubscriptionID{}, err } - if c.dataProviders.Has(newID) { - return uuid.Nil, fmt.Errorf("subscription id is already in use") + if c.dataProviders.Has(newId) { + return SubscriptionID{}, fmt.Errorf("such subscription is already in use: %s", newId) } - return newID, nil + return newId, nil } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 422f6a6c1be..80524482597 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -69,7 +69,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - SubscriptionID: uuid.New().String(), + SubscriptionID: "dummy-id", Action: models.SubscribeAction, }, Topic: dp.BlocksTopic, @@ -169,7 +169,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}) - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) conn. @@ -213,7 +213,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}) - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) @@ -263,7 +263,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) @@ -332,13 +332,13 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - SubscriptionID: "invalid-uuid", + SubscriptionID: uuid.New().String() + " .42", // invalid subscription ID Action: models.UnsubscribeAction, }, } @@ -401,13 +401,13 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - SubscriptionID: uuid.New().String(), // unknown subscription id + SubscriptionID: "unknown-sub-id", Action: models.UnsubscribeAction, }, } @@ -475,7 +475,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(nil). Once() - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) @@ -550,7 +550,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}) - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) @@ -606,7 +606,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}) - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) @@ -806,7 +806,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { Once() done := make(chan struct{}) - subscriptionID := uuid.New().String() + subscriptionID := "dummy-id" s.expectSubscribeRequest(t, conn, subscriptionID) s.expectSubscribeResponse(t, conn, subscriptionID) diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider.go b/engine/access/rest/websockets/data_providers/account_statuses_provider.go index 97c23f31719..9d799f220ee 100644 --- a/engine/access/rest/websockets/data_providers/account_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -42,7 +41,7 @@ func NewAccountStatusesDataProvider( ctx context.Context, logger zerolog.Logger, stateStreamApi state_stream.API, - subscriptionID uuid.UUID, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go index 87b177d33f4..7ff14ea8597 100644 --- a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -178,7 +177,7 @@ func (s *AccountStatusesProviderSuite) TestAccountStatusesDataProvider_InvalidAr ctx, s.log, s.api, - uuid.New(), + "dummy-id", topic, test.arguments, send, @@ -220,7 +219,7 @@ func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderRe ctx, s.log, s.api, - uuid.New(), + "dummy-id", topic, arguments, send, diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go index 8936a883161..27b757dbf74 100644 --- a/engine/access/rest/websockets/data_providers/base_provider.go +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -3,14 +3,12 @@ package data_providers import ( "context" - "github.com/google/uuid" - "github.com/onflow/flow-go/engine/access/subscription" ) // baseDataProvider holds common objects for the provider type baseDataProvider struct { - subscriptionID uuid.UUID + subscriptionID string topic string cancel context.CancelFunc send chan<- interface{} @@ -19,7 +17,7 @@ type baseDataProvider struct { // newBaseDataProvider creates a new instance of baseDataProvider. func newBaseDataProvider( - subscriptionID uuid.UUID, + subscriptionID string, topic string, cancel context.CancelFunc, send chan<- interface{}, @@ -35,7 +33,7 @@ func newBaseDataProvider( } // ID returns the subscription ID associated with current data provider -func (b *baseDataProvider) ID() uuid.UUID { +func (b *baseDataProvider) ID() string { return b.subscriptionID } diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider.go b/engine/access/rest/websockets/data_providers/block_digests_provider.go index a60dae0bbc4..12d46daf03f 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -29,7 +28,7 @@ func NewBlockDigestsDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, - subscriptionID uuid.UUID, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go index 6bc18ee0650..395e57dfa04 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go @@ -5,7 +5,6 @@ import ( "strconv" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -44,7 +43,7 @@ func (s *BlockDigestsProviderSuite) TestBlockDigestsDataProvider_InvalidArgument for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlockDigestsDataProvider(ctx, s.log, s.api, uuid.New(), topic, test.arguments, send) + provider, err := NewBlockDigestsDataProvider(ctx, s.log, s.api, "dummy-id", topic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider.go b/engine/access/rest/websockets/data_providers/block_headers_provider.go index 58a52ea13c4..d6b39d17082 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -30,7 +29,7 @@ func NewBlockHeadersDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, - subscriptionID uuid.UUID, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go index f2b844117eb..8834d21d498 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go @@ -5,7 +5,6 @@ import ( "strconv" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -45,7 +44,7 @@ func (s *BlockHeadersProviderSuite) TestBlockHeadersDataProvider_InvalidArgument for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlockHeadersDataProvider(ctx, s.log, s.api, uuid.New(), topic, test.arguments, send) + provider, err := NewBlockHeadersDataProvider(ctx, s.log, s.api, "dummy-id", topic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/blocks_provider.go b/engine/access/rest/websockets/data_providers/blocks_provider.go index 15635660409..00f9e6120dc 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -41,7 +40,7 @@ func NewBlocksDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, - subscriptionID uuid.UUID, + subscriptionID string, linkGenerator commonmodels.LinkGenerator, topic string, arguments models.Arguments, diff --git a/engine/access/rest/websockets/data_providers/blocks_provider_test.go b/engine/access/rest/websockets/data_providers/blocks_provider_test.go index d75e7fb32aa..2754572dff7 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider_test.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider_test.go @@ -6,7 +6,6 @@ import ( "strconv" "testing" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -132,7 +131,7 @@ func (s *BlocksProviderSuite) TestBlocksDataProvider_InvalidArguments() { for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlocksDataProvider(ctx, s.log, s.api, uuid.New(), nil, BlocksTopic, test.arguments, send) + provider, err := NewBlocksDataProvider(ctx, s.log, s.api, "dummy-id", nil, BlocksTopic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/data_provider.go b/engine/access/rest/websockets/data_providers/data_provider.go index ed6c11b0f0d..acaa857ead2 100644 --- a/engine/access/rest/websockets/data_providers/data_provider.go +++ b/engine/access/rest/websockets/data_providers/data_provider.go @@ -1,14 +1,10 @@ package data_providers -import ( - "github.com/google/uuid" -) - // The DataProvider is the interface abstracts of the actual data provider used by the WebSocketCollector. // It provides methods for retrieving the provider's unique SubscriptionID, topic, and a methods to close and run the provider. type DataProvider interface { // ID returns the unique identifier of the data provider. - ID() uuid.UUID + ID() string // Topic returns the topic associated with the data provider. Topic() string // Close terminates the data provider. diff --git a/engine/access/rest/websockets/data_providers/events_provider.go b/engine/access/rest/websockets/data_providers/events_provider.go index 42b4acc036f..22979ef4d16 100644 --- a/engine/access/rest/websockets/data_providers/events_provider.go +++ b/engine/access/rest/websockets/data_providers/events_provider.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/common/parser" @@ -41,7 +40,7 @@ func NewEventsDataProvider( ctx context.Context, logger zerolog.Logger, stateStreamApi state_stream.API, - subscriptionID uuid.UUID, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, diff --git a/engine/access/rest/websockets/data_providers/events_provider_test.go b/engine/access/rest/websockets/data_providers/events_provider_test.go index 4089f23ea8c..4fbe4908ca8 100644 --- a/engine/access/rest/websockets/data_providers/events_provider_test.go +++ b/engine/access/rest/websockets/data_providers/events_provider_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -204,7 +203,7 @@ func (s *EventsProviderSuite) TestEventsDataProvider_InvalidArguments() { ctx, s.log, s.api, - uuid.New(), + "dummy-id", topic, test.arguments, send, @@ -246,7 +245,7 @@ func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() ctx, s.log, s.api, - uuid.New(), + "dummy-id", topic, arguments, send, diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index 38f5fda6da1..e0e5489d457 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/onflow/flow-go/access" @@ -34,7 +33,7 @@ type DataProviderFactory interface { // and configuration parameters. // // No errors are expected during normal operations. - NewDataProvider(ctx context.Context, subscriptionID uuid.UUID, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) + NewDataProvider(ctx context.Context, subscriptionID string, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) } var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) @@ -92,13 +91,7 @@ func NewDataProviderFactory( // - ch: Channel to which the data provider sends data. // // No errors are expected during normal operations. -func (s *DataProviderFactoryImpl) NewDataProvider( - ctx context.Context, - subscriptionID uuid.UUID, - topic string, - arguments models.Arguments, - ch chan<- interface{}, -) (DataProvider, error) { +func (s *DataProviderFactoryImpl) NewDataProvider(ctx context.Context, subscriptionID string, topic string, arguments models.Arguments, ch chan<- interface{}) (DataProvider, error) { switch topic { case BlocksTopic: return NewBlocksDataProvider(ctx, s.logger, s.accessApi, subscriptionID, s.linkGenerator, topic, arguments, ch) diff --git a/engine/access/rest/websockets/data_providers/factory_test.go b/engine/access/rest/websockets/data_providers/factory_test.go index bdc05afc7c0..c89a532fddb 100644 --- a/engine/access/rest/websockets/data_providers/factory_test.go +++ b/engine/access/rest/websockets/data_providers/factory_test.go @@ -5,7 +5,6 @@ import ( "fmt" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -161,7 +160,7 @@ func (s *DataProviderFactorySuite) TestSupportedTopics() { s.T().Parallel() test.setupSubscription() - provider, err := s.factory.NewDataProvider(s.ctx, uuid.New(), test.topic, test.arguments, s.ch) + provider, err := s.factory.NewDataProvider(s.ctx, "dummy-id", test.topic, test.arguments, s.ch) s.Require().NotNil(provider, "Expected provider for topic %s", test.topic) s.Require().NoError(err, "Expected no error for topic %s", test.topic) s.Require().Equal(test.topic, provider.Topic()) @@ -183,7 +182,7 @@ func (s *DataProviderFactorySuite) TestUnsupportedTopics() { } for _, topic := range unsupportedTopics { - provider, err := s.factory.NewDataProvider(s.ctx, uuid.New(), topic, nil, s.ch) + provider, err := s.factory.NewDataProvider(s.ctx, "dummy-id", topic, nil, s.ch) s.Require().Nil(provider, "Expected no provider for unsupported topic %s", topic) s.Require().Error(err, "Expected error for unsupported topic %s", topic) s.Require().EqualError(err, fmt.Sprintf("unsupported topic \"%s\"", topic)) diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider.go b/engine/access/rest/websockets/data_providers/mock/data_provider.go index 48debb23ae3..478f1625ad5 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider.go @@ -2,10 +2,7 @@ package mock -import ( - uuid "github.com/google/uuid" - mock "github.com/stretchr/testify/mock" -) +import mock "github.com/stretchr/testify/mock" // DataProvider is an autogenerated mock type for the DataProvider type type DataProvider struct { @@ -18,20 +15,18 @@ func (_m *DataProvider) Close() { } // ID provides a mock function with given fields: -func (_m *DataProvider) ID() uuid.UUID { +func (_m *DataProvider) ID() string { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for ID") } - var r0 uuid.UUID - if rf, ok := ret.Get(0).(func() uuid.UUID); ok { + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(uuid.UUID) - } + r0 = ret.Get(0).(string) } return r0 diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go index 7c7d4bc58c0..c18fcc5e56a 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go @@ -9,8 +9,6 @@ import ( mock "github.com/stretchr/testify/mock" models "github.com/onflow/flow-go/engine/access/rest/websockets/models" - - uuid "github.com/google/uuid" ) // DataProviderFactory is an autogenerated mock type for the DataProviderFactory type @@ -19,7 +17,7 @@ type DataProviderFactory struct { } // NewDataProvider provides a mock function with given fields: ctx, subscriptionID, topic, args, ch -func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, subscriptionID uuid.UUID, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { +func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, subscriptionID string, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { ret := _m.Called(ctx, subscriptionID, topic, args, ch) if len(ret) == 0 { @@ -28,10 +26,10 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, subscription var r0 data_providers.DataProvider var r1 error - if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { return rf(ctx, subscriptionID, topic, args, ch) } - if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { r0 = rf(ctx, subscriptionID, topic, args, ch) } else { if ret.Get(0) != nil { @@ -39,7 +37,7 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, subscription } } - if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID, string, models.Arguments, chan<- interface{}) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, string, models.Arguments, chan<- interface{}) error); ok { r1 = rf(ctx, subscriptionID, topic, args, ch) } else { r1 = ret.Error(1) diff --git a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go index 42838f2eff0..cf1e54919ec 100644 --- a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -40,7 +39,7 @@ func NewSendAndGetTransactionStatusesDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, - subscriptionID uuid.UUID, + subscriptionID string, linkGenerator commonmodels.LinkGenerator, topic string, arguments models.Arguments, diff --git a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go index 75932ebde11..0907f70c674 100644 --- a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -136,7 +135,7 @@ func (s *SendTransactionStatusesProviderSuite) TestSendTransactionStatusesDataPr ctx, s.log, s.api, - uuid.New(), + "dummy-id", s.linkGenerator, topic, test.arguments, diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go index 6d6c14d46b4..4bdf6a53359 100644 --- a/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -43,7 +42,7 @@ func NewTransactionStatusesDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, - subscriptionID uuid.UUID, + subscriptionID string, linkGenerator commonmodels.LinkGenerator, topic string, arguments models.Arguments, diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go index fa05becd429..d28f6ae671c 100644 --- a/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -188,7 +187,7 @@ func (s *TransactionStatusesProviderSuite) TestTransactionStatusesDataProvider_I ctx, s.log, s.api, - uuid.New(), + "dummy-id", s.linkGenerator, topic, test.arguments, @@ -283,7 +282,7 @@ func (s *TransactionStatusesProviderSuite) TestMessageIndexTransactionStatusesPr ctx, s.log, s.api, - uuid.New(), + "dummy-id", s.linkGenerator, topic, arguments, diff --git a/engine/access/rest/websockets/data_providers/unit_test.go b/engine/access/rest/websockets/data_providers/unit_test.go index b3ad732caa3..cbc75393db9 100644 --- a/engine/access/rest/websockets/data_providers/unit_test.go +++ b/engine/access/rest/websockets/data_providers/unit_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/onflow/flow-go/engine/access/rest/websockets/models" @@ -64,7 +63,7 @@ func testHappyPath( test.setupBackend(sub) // Create the data provider instance - provider, err := factory.NewDataProvider(ctx, uuid.New(), topic, test.arguments, send) + provider, err := factory.NewDataProvider(ctx, "dummy-id", topic, test.arguments, send) require.NotNil(t, provider) require.NoError(t, err) diff --git a/engine/access/rest/websockets/subscription_id.go b/engine/access/rest/websockets/subscription_id.go new file mode 100644 index 00000000000..09ffa5f7d5e --- /dev/null +++ b/engine/access/rest/websockets/subscription_id.go @@ -0,0 +1,54 @@ +package websockets + +import ( + "fmt" + + "github.com/google/uuid" +) + +const maxLen = 20 + +// SubscriptionID represents a subscription identifier used in websockets. +// The ID can either be provided by the client or generated by the server. +// - If provided by the client, it must adhere to specific restrictions. +// - If generated by the server, it is created as a UUID. +type SubscriptionID struct { + id string +} + +// NewSubscriptionID creates a new SubscriptionID based on the provided input. +// - If the input `id` is empty, a new UUID is generated and returned. +// - If the input `id` is non-empty, it is validated and returned if no errors. +func NewSubscriptionID(id string) (SubscriptionID, error) { + if len(id) == 0 { + return SubscriptionID{ + id: uuid.New().String(), + }, nil + } + + newID, err := ParseClientSubscriptionID(id) + if err != nil { + return SubscriptionID{}, err + } + + return newID, nil +} + +func ParseClientSubscriptionID(id string) (SubscriptionID, error) { + if len(id) == 0 { + return SubscriptionID{}, fmt.Errorf("subscription ID provided by the client must not be empty") + } + + if len(id) > maxLen { + return SubscriptionID{}, fmt.Errorf("subscription ID provided by the client must not exceed %d characters", maxLen) + } + + return SubscriptionID{ + id: id, + }, nil +} + +// String returns the string representation of the SubscriptionID. +func (id SubscriptionID) String() string { + return id.id +} From 37ac1f63619429cbc4c016b0fe4933b44d0ec744 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Wed, 15 Jan 2025 14:51:02 +0200 Subject: [PATCH 08/10] add field to response messages --- engine/access/rest/websockets/controller.go | 23 ++++++++++++------- .../access/rest/websockets/controller_test.go | 9 +++++++- .../rest/websockets/models/base_message.go | 1 + .../websockets/models/list_subscriptions.go | 1 + 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index efc1f532bce..5a4b8a9e6b6 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -318,7 +318,7 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(http.StatusBadRequest, "error reading message", ""), + wrapErrorMessage(http.StatusBadRequest, "error reading message", "", ""), ) continue } @@ -328,7 +328,7 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(http.StatusBadRequest, "error parsing message", ""), + wrapErrorMessage(http.StatusBadRequest, "error parsing message", "", ""), ) continue } @@ -377,7 +377,8 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", msg.SubscriptionID), + wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", + models.SubscribeAction, msg.SubscriptionID), ) return } @@ -388,7 +389,8 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(http.StatusBadRequest, "error creating data provider", subscriptionID.String()), + wrapErrorMessage(http.StatusBadRequest, "error creating data provider", + models.SubscribeAction, subscriptionID.String()), ) return } @@ -410,7 +412,8 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(http.StatusInternalServerError, "internal error", subscriptionID.String()), + wrapErrorMessage(http.StatusInternalServerError, "internal error", + models.SubscribeAction, subscriptionID.String()), ) } @@ -425,7 +428,8 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", msg.SubscriptionID), + wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", + models.UnsubscribeAction, msg.SubscriptionID), ) return } @@ -435,7 +439,8 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(http.StatusNotFound, "subscription not found", subscriptionID.String()), + wrapErrorMessage(http.StatusNotFound, "subscription not found", + models.UnsubscribeAction, subscriptionID.String()), ) return } @@ -468,6 +473,7 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListS responseOk := models.ListSubscriptionsMessageResponse{ Subscriptions: subs, + Action: models.ListSubscriptionsAction, } c.writeResponse(ctx, responseOk) } @@ -504,13 +510,14 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) { } } -func wrapErrorMessage(code int, message string, subscriptionID string) models.BaseMessageResponse { +func wrapErrorMessage(code int, message string, action string, subscriptionID string) models.BaseMessageResponse { return models.BaseMessageResponse{ SubscriptionID: subscriptionID, Error: models.ErrorMessage{ Code: code, Message: message, }, + Action: action, } } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 80524482597..a8870f4a171 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -109,7 +109,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { dataProvider.AssertExpectations(t) }) - s.T().Run("Parse and validate error", func(t *testing.T) { + s.T().Run("Validate message error", func(t *testing.T) { t.Parallel() conn, dataProviderFactory, _ := newControllerMocks(t) @@ -146,6 +146,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { require.True(t, ok) require.NotEmpty(t, response.Error) require.Equal(t, http.StatusBadRequest, response.Error.Code) + require.Equal(t, "", response.Action) return websocket.ErrCloseSent }) @@ -181,6 +182,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { require.True(t, ok) require.NotEmpty(t, response.Error) require.Equal(t, http.StatusBadRequest, response.Error.Code) + require.Equal(t, models.SubscribeAction, response.Action) return websocket.ErrCloseSent }) @@ -226,6 +228,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { require.True(t, ok) require.NotEmpty(t, response.Error) require.Equal(t, http.StatusInternalServerError, response.Error.Code) + require.Equal(t, models.SubscribeAction, response.Action) return websocket.ErrCloseSent }) @@ -365,6 +368,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.NotEmpty(t, response.Error) require.Equal(t, request.SubscriptionID, response.SubscriptionID) require.Equal(t, http.StatusBadRequest, response.Error.Code) + require.Equal(t, models.UnsubscribeAction, response.Action) return websocket.ErrCloseSent }). @@ -436,6 +440,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.NotEmpty(t, response.Error) require.Equal(t, http.StatusNotFound, response.Error.Code) + require.Equal(t, models.UnsubscribeAction, response.Action) + return websocket.ErrCloseSent }). Once() @@ -508,6 +514,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { require.Equal(t, 1, len(response.Subscriptions)) require.Equal(t, subscriptionID, response.Subscriptions[0].SubscriptionID) require.Equal(t, topic, response.Subscriptions[0].Topic) + require.Equal(t, models.ListSubscriptionsAction, response.Action) return websocket.ErrCloseSent }). diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index dc26a2914b0..09c10d3ef8c 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -18,6 +18,7 @@ type BaseMessageRequest struct { type BaseMessageResponse struct { SubscriptionID string `json:"subscription_id"` // SubscriptionID might be empty in case of error response Error ErrorMessage `json:"error,omitempty"` // Error might be empty in case of OK response + Action string `json:"action"` } type ErrorMessage struct { diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go index 185fc65bbe0..49c8edf5b96 100644 --- a/engine/access/rest/websockets/models/list_subscriptions.go +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -10,4 +10,5 @@ type ListSubscriptionsMessageRequest struct { type ListSubscriptionsMessageResponse struct { // Subscription list might be empty in case of no active subscriptions Subscriptions []*SubscriptionEntry `json:"subscriptions"` + Action string `json:"action"` } From cca9706be9a42bd6761f6edb17c01562cc0858dd Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 21 Jan 2025 12:57:23 +0200 Subject: [PATCH 09/10] add tests for subscription id struct --- engine/access/rest/websockets/controller.go | 12 +--- .../rest/websockets/data_providers/factory.go | 8 ++- .../rest/websockets/subscription_id_test.go | 60 +++++++++++++++++++ 3 files changed, 69 insertions(+), 11 deletions(-) create mode 100644 engine/access/rest/websockets/subscription_id_test.go diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 5a4b8a9e6b6..23668141f92 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -458,7 +458,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListSubscriptionsMessageRequest) { var subs []*models.SubscriptionEntry - err := c.dataProviders.ForEach(func(id SubscriptionID, provider dp.DataProvider) error { + _ = c.dataProviders.ForEach(func(id SubscriptionID, provider dp.DataProvider) error { subs = append(subs, &models.SubscriptionEntry{ SubscriptionID: id.String(), Topic: provider.Topic(), @@ -466,11 +466,6 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListS return nil }) - // intentionally ignored, this never happens - if err != nil { - c.logger.Debug().Err(err).Msg("error listing subscriptions") - } - responseOk := models.ListSubscriptionsMessageResponse{ Subscriptions: subs, Action: models.ListSubscriptionsAction, @@ -484,13 +479,10 @@ func (c *Controller) shutdownConnection() { c.logger.Debug().Err(err).Msg("error closing connection") } - err = c.dataProviders.ForEach(func(_ SubscriptionID, provider dp.DataProvider) error { + _ = c.dataProviders.ForEach(func(_ SubscriptionID, provider dp.DataProvider) error { provider.Close() return nil }) - if err != nil { - c.logger.Debug().Err(err).Msg("error closing data provider") - } c.dataProviders.Clear() c.dataProvidersGroup.Wait() diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index e0e5489d457..02d1a1320dd 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -33,7 +33,13 @@ type DataProviderFactory interface { // and configuration parameters. // // No errors are expected during normal operations. - NewDataProvider(ctx context.Context, subscriptionID string, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) + NewDataProvider( + ctx context.Context, + subscriptionID string, + topic string, + args models.Arguments, + ch chan<- interface{}, + ) (DataProvider, error) } var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) diff --git a/engine/access/rest/websockets/subscription_id_test.go b/engine/access/rest/websockets/subscription_id_test.go new file mode 100644 index 00000000000..bbe5c3f7f9a --- /dev/null +++ b/engine/access/rest/websockets/subscription_id_test.go @@ -0,0 +1,60 @@ +package websockets + +import ( + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewSubscriptionID(t *testing.T) { + t.Run("should generate new ID when input ID is empty", func(t *testing.T) { + subscriptionID, err := NewSubscriptionID("") + + assert.NoError(t, err) + assert.NotEmpty(t, subscriptionID.id) + assert.NoError(t, uuid.Validate(subscriptionID.id), "Generated ID should be a valid UUID") + }) + + t.Run("should return valid SubscriptionID when input ID is valid", func(t *testing.T) { + validID := "subscription/blocks" + subscriptionID, err := NewSubscriptionID(validID) + + assert.NoError(t, err) + assert.Equal(t, validID, subscriptionID.id) + }) + + t.Run("should return an error for invalid input in ParseClientSubscriptionID", func(t *testing.T) { + longID := fmt.Sprintf("%s%s", "id-", make([]byte, maxLen+1)) + _, err := NewSubscriptionID(longID) + + assert.Error(t, err) + assert.EqualError(t, err, fmt.Sprintf("subscription ID provided by the client must not exceed %d characters", maxLen)) + }) +} + +func TestParseClientSubscriptionID(t *testing.T) { + t.Run("should return error if input ID is empty", func(t *testing.T) { + _, err := ParseClientSubscriptionID("") + + assert.Error(t, err) + assert.EqualError(t, err, "subscription ID provided by the client must not be empty") + }) + + t.Run("should return error if input ID exceeds max length", func(t *testing.T) { + longID := fmt.Sprintf("%s%s", "id-", make([]byte, maxLen+1)) + _, err := ParseClientSubscriptionID(longID) + + assert.Error(t, err) + assert.EqualError(t, err, fmt.Sprintf("subscription ID provided by the client must not exceed %d characters", maxLen)) + }) + + t.Run("should return valid SubscriptionID for valid input", func(t *testing.T) { + validID := "subscription/blocks" + subscription, err := ParseClientSubscriptionID(validID) + + assert.NoError(t, err) + assert.Equal(t, validID, subscription.id) + }) +} From 99da7164e2a27f648ed5b82439db939c7005c16b Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 21 Jan 2025 12:58:20 +0200 Subject: [PATCH 10/10] change error message --- engine/access/rest/websockets/controller.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 23668141f92..7f064211eb2 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -520,7 +520,7 @@ func (c *Controller) parseOrCreateSubscriptionID(id string) (SubscriptionID, err } if c.dataProviders.Has(newId) { - return SubscriptionID{}, fmt.Errorf("such subscription is already in use: %s", newId) + return SubscriptionID{}, fmt.Errorf("subscription ID is already in use: %s", newId) } return newId, nil