Skip to content

Commit

Permalink
[DEC-2081] Audit app/prepare/prepare_proposal.go (#567)
Browse files Browse the repository at this point in the history
Mainly swap to use multi value return with error as its own type since this is idiomatic for Go.
Minor test improvements.
  • Loading branch information
lcwik authored Oct 12, 2023
1 parent 1d21554 commit 0f1bc2f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 61 deletions.
79 changes: 38 additions & 41 deletions protocol/app/prepare/prepare_proposal.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,24 @@ var (
// PricesTxResponse represents a response for creating `UpdateMarketPrices` tx.
type PricesTxResponse struct {
Tx []byte
Err error
NumMarkets int
}

// FundingTxResponse represents a response for creating `AddPremiumVotes` tx.
type FundingTxResponse struct {
Tx []byte
Err error
NumVotes int
}

// OperationTxResponse represents a response for creating 'ProposedOperations' tx
type OperationsTxResponse struct {
Tx []byte
Err error
NumOperations int
}

// BridgeTxResponse represents a response for creating 'AcknowledgeBridges' tx
type BridgeTxResponse struct {
Tx []byte
Err error
NumBridges int
}

Expand Down Expand Up @@ -76,9 +72,9 @@ func PrepareProposalHandler(
}

// Gather "FixedSize" group messages.
pricesTxResp := GetUpdateMarketPricesTx(ctx, txConfig, req.ProposerAddress, pricesKeeper)
if pricesTxResp.Err != nil {
ctx.Logger().Error(fmt.Sprintf("GetUpdateMarketPricesTx error: %v", pricesTxResp.Err))
pricesTxResp, err := GetUpdateMarketPricesTx(ctx, txConfig, pricesKeeper)
if err != nil {
ctx.Logger().Error(fmt.Sprintf("GetUpdateMarketPricesTx error: %v", err))
recordErrorMetricsWithLabel(metrics.PricesTx)
return EmptyResponse
}
Expand All @@ -89,9 +85,9 @@ func PrepareProposalHandler(
return EmptyResponse
}

fundingTxResp := GetAddPremiumVotesTx(ctx, txConfig, perpetualKeeper)
if fundingTxResp.Err != nil {
ctx.Logger().Error(fmt.Sprintf("GetAddPremiumVotesTx error: %v", fundingTxResp.Err))
fundingTxResp, err := GetAddPremiumVotesTx(ctx, txConfig, perpetualKeeper)
if err != nil {
ctx.Logger().Error(fmt.Sprintf("GetAddPremiumVotesTx error: %v", err))
recordErrorMetricsWithLabel(metrics.FundingTx)
return EmptyResponse
}
Expand All @@ -102,9 +98,9 @@ func PrepareProposalHandler(
return EmptyResponse
}

acknowledgeBridgesTxResp := GetAcknowledgeBridgesTx(ctx, txConfig, bridgeKeeper)
if acknowledgeBridgesTxResp.Err != nil {
ctx.Logger().Error(fmt.Sprintf("GetAcknowledgeBridgesTx error: %v", acknowledgeBridgesTxResp.Err))
acknowledgeBridgesTxResp, err := GetAcknowledgeBridgesTx(ctx, txConfig, bridgeKeeper)
if err != nil {
ctx.Logger().Error(fmt.Sprintf("GetAcknowledgeBridgesTx error: %v", err))
recordErrorMetricsWithLabel(metrics.AcknowledgeBridgesTx)
return EmptyResponse
}
Expand Down Expand Up @@ -133,9 +129,9 @@ func PrepareProposalHandler(

// Gather "OperationsRelated" group messages.
// TODO(DEC-1237): ensure ProposedOperations is within a certain size.
operationsTxResp := GetProposedOperationsTx(ctx, txConfig, clobKeeper)
if operationsTxResp.Err != nil {
ctx.Logger().Error(fmt.Sprintf("GetProposedOperationsTx error: %v", operationsTxResp.Err))
operationsTxResp, err := GetProposedOperationsTx(ctx, txConfig, clobKeeper)
if err != nil {
ctx.Logger().Error(fmt.Sprintf("GetProposedOperationsTx error: %v", err))
recordErrorMetricsWithLabel(metrics.OperationsTx)
return EmptyResponse
}
Expand All @@ -151,7 +147,7 @@ func PrepareProposalHandler(
if availableBytes > 0 && len(otherTxsRemainder) > 0 {
moreOtherTxsToInclude, _ := GetGroupMsgOther(otherTxsRemainder, availableBytes)
if len(moreOtherTxsToInclude) > 0 {
err := txs.AddOtherTxs(moreOtherTxsToInclude)
err = txs.AddOtherTxs(moreOtherTxsToInclude)
if err != nil {
ctx.Logger().Error(fmt.Sprintf("AddOtherTxs (additional) error: %v", err))
recordErrorMetricsWithLabel(metrics.OtherTxs)
Expand Down Expand Up @@ -185,102 +181,103 @@ func PrepareProposalHandler(
}

// GetUpdateMarketPricesTx returns a tx containing `MsgUpdateMarketPrices`.
// The response contains an error if encoding fails.
func GetUpdateMarketPricesTx(
ctx sdk.Context,
txConfig client.TxConfig,
proposerAddress []byte,
pricesKeeper PreparePricesKeeper,
) PricesTxResponse {
) (PricesTxResponse, error) {
// Get prices to update.
msgUpdateMarketPrices := pricesKeeper.GetValidMarketPriceUpdates(ctx)
if msgUpdateMarketPrices == nil {
return PricesTxResponse{Err: fmt.Errorf("MsgUpdateMarketPrices cannot be nil")}
return PricesTxResponse{}, fmt.Errorf("MsgUpdateMarketPrices cannot be nil")
}

tx, err := EncodeMsgsIntoTxBytes(txConfig, msgUpdateMarketPrices)
if err != nil {
return PricesTxResponse{Err: err}
return PricesTxResponse{}, err
}
if len(tx) == 0 {
return PricesTxResponse{Err: fmt.Errorf("Invalid tx: %v", tx)}
return PricesTxResponse{}, fmt.Errorf("Invalid tx: %v", tx)
}

return PricesTxResponse{
Tx: tx,
NumMarkets: len(msgUpdateMarketPrices.MarketPriceUpdates),
}
}, nil
}

// GetAddPremiumVotesTx returns a tx containing `MsgAddPremiumVotes`.
// The response contains an error if encoding fails.
func GetAddPremiumVotesTx(
ctx sdk.Context,
txConfig client.TxConfig,
perpetualsKeeper PreparePerpetualsKeeper,
) FundingTxResponse {
) (FundingTxResponse, error) {
// Get premium votes.
msgAddPremiumVotes := perpetualsKeeper.GetAddPremiumVotes(ctx)
if msgAddPremiumVotes == nil {
return FundingTxResponse{Err: fmt.Errorf("MsgAddPremiumVotes cannot be nil")}
return FundingTxResponse{}, fmt.Errorf("MsgAddPremiumVotes cannot be nil")
}

tx, err := EncodeMsgsIntoTxBytes(txConfig, msgAddPremiumVotes)
if err != nil {
return FundingTxResponse{Err: err}
return FundingTxResponse{}, err
}
if len(tx) == 0 {
return FundingTxResponse{Err: fmt.Errorf("Invalid tx: %v", tx)}
return FundingTxResponse{}, fmt.Errorf("Invalid tx: %v", tx)
}

return FundingTxResponse{
Tx: tx,
NumVotes: len(msgAddPremiumVotes.Votes),
}
}, nil
}

// GetProposedOperationsTx returns a tx containing `MsgProposedOperations`.
// The response contains an error if encoding fails.
func GetProposedOperationsTx(
ctx sdk.Context,
txConfig client.TxConfig,
clobKeeper PrepareClobKeeper,
) OperationsTxResponse {
) (OperationsTxResponse, error) {
// Get the order and fill messages from the CLOB keeper.
msgOperations := clobKeeper.GetOperations(ctx)
if msgOperations == nil {
return OperationsTxResponse{Err: fmt.Errorf("MsgProposedOperations cannot be nil")}
return OperationsTxResponse{}, fmt.Errorf("MsgProposedOperations cannot be nil")
}

tx, err := EncodeMsgsIntoTxBytes(txConfig, msgOperations)
if err != nil {
return OperationsTxResponse{Err: err}
return OperationsTxResponse{}, err
}
if len(tx) == 0 {
return OperationsTxResponse{Err: fmt.Errorf("Invalid tx: %v", tx)}
return OperationsTxResponse{}, fmt.Errorf("Invalid tx: %v", tx)
}

return OperationsTxResponse{Tx: tx, NumOperations: len(msgOperations.GetOperationsQueue())}
return OperationsTxResponse{
Tx: tx,
NumOperations: len(msgOperations.GetOperationsQueue()),
}, nil
}

// GetAcknowledgeBridgeTx returns a tx containing a list of `MsgAcknowledgeBridge`.
// The response contains an error if encoding fails.
func GetAcknowledgeBridgesTx(
ctx sdk.Context,
txConfig client.TxConfig,
bridgeKeeper PrepareBridgeKeeper,
) BridgeTxResponse {
) (BridgeTxResponse, error) {
msgAcknowledgeBridges := bridgeKeeper.GetAcknowledgeBridges(ctx, ctx.BlockTime())

tx, err := EncodeMsgsIntoTxBytes(txConfig, msgAcknowledgeBridges)
if err != nil {
return BridgeTxResponse{Err: err}
return BridgeTxResponse{}, err
}
if len(tx) == 0 {
return BridgeTxResponse{Err: fmt.Errorf("Invalid tx: %v", tx)}
return BridgeTxResponse{}, fmt.Errorf("Invalid tx: %v", tx)
}

return BridgeTxResponse{Tx: tx, NumBridges: len(msgAcknowledgeBridges.Events)}
return BridgeTxResponse{
Tx: tx,
NumBridges: len(msgAcknowledgeBridges.Events),
}, nil
}

// EncodeMsgsIntoTxBytes encodes the given msgs into a single transaction.
Expand Down
35 changes: 15 additions & 20 deletions protocol/app/prepare/prepare_proposal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ import (
)

var (
ctx = sdktypes.Context{}
address = []byte{1, 2, 3, 4}
ctx = sdktypes.Context{}

failingTxEncoder = func(tx sdktypes.Tx) ([]byte, error) {
return nil, errors.New("encoder failed")
Expand All @@ -41,10 +40,6 @@ var (
}
)

func TestEmptyResponse(t *testing.T) {
require.Equal(t, abci.ResponsePrepareProposal{Txs: [][]byte{}}, prepare.EmptyResponse)
}

func TestPrepareProposalHandler(t *testing.T) {
msgSendTxBytesLen := int64(len(constants.Msg_Send_TxBytes))
msgSendAndTransferTxBytesLen := int64(len(constants.Msg_SendAndTransfer_TxBytes))
Expand Down Expand Up @@ -318,7 +313,7 @@ func TestPrepareProposalHandler(t *testing.T) {
expectedTxs: [][]byte{
{1, 2, 3, 4}, // order.
constants.Msg_Send_TxBytes, // others.
constants.Msg_SendAndTransfer_TxBytes, // addtional others.
constants.Msg_SendAndTransfer_TxBytes, // additional others.
{1, 2, 3, 4}, // bridge.
{1, 2, 3, 4}, // funding.
{1, 2, 3, 4}, // prices.
Expand Down Expand Up @@ -504,11 +499,11 @@ func TestGetUpdateMarketPricesTx(t *testing.T) {
mockPricesKeeper.On("GetValidMarketPriceUpdates", mock.Anything).
Return(tc.keeperResp)

resp := prepare.GetUpdateMarketPricesTx(ctx, mockTxConfig, address, &mockPricesKeeper)
resp, err := prepare.GetUpdateMarketPricesTx(ctx, mockTxConfig, &mockPricesKeeper)
if tc.expectedErr != nil {
require.ErrorContains(t, resp.Err, tc.expectedErr.Error())
require.Equal(t, err, tc.expectedErr)
} else {
require.NoError(t, resp.Err)
require.NoError(t, err)
}
require.Equal(t, tc.expectedTx, resp.Tx)
require.Equal(t, tc.expectedNumMarkets, resp.NumMarkets)
Expand Down Expand Up @@ -566,11 +561,11 @@ func TestGetAcknowledgeBridgesTx(t *testing.T) {
mockBridgeKeeper.On("GetAcknowledgeBridges", mock.Anything, mock.Anything).
Return(tc.keeperResp)

resp := prepare.GetAcknowledgeBridgesTx(ctx, mockTxConfig, &mockBridgeKeeper)
resp, err := prepare.GetAcknowledgeBridgesTx(ctx, mockTxConfig, &mockBridgeKeeper)
if tc.expectedErr != nil {
require.ErrorContains(t, resp.Err, tc.expectedErr.Error())
require.Equal(t, err, tc.expectedErr)
} else {
require.NoError(t, resp.Err)
require.NoError(t, err)
}
require.Equal(t, tc.expectedTx, resp.Tx)
require.Equal(t, tc.expectedNumBridges, resp.NumBridges)
Expand Down Expand Up @@ -628,11 +623,11 @@ func TestGetAddPremiumVotesTx(t *testing.T) {
mockPerpKeeper.On("GetAddPremiumVotes", mock.Anything).
Return(tc.keeperResp)

resp := prepare.GetAddPremiumVotesTx(ctx, mockTxConfig, &mockPerpKeeper)
resp, err := prepare.GetAddPremiumVotesTx(ctx, mockTxConfig, &mockPerpKeeper)
if tc.expectedErr != nil {
require.ErrorContains(t, resp.Err, tc.expectedErr.Error())
require.Equal(t, err, tc.expectedErr)
} else {
require.NoError(t, resp.Err)
require.NoError(t, err)
}
require.Equal(t, tc.expectedTx, resp.Tx)
require.Equal(t, tc.expectedNumVotes, resp.NumVotes)
Expand Down Expand Up @@ -692,11 +687,11 @@ func TestGetProposedOperationsTx(t *testing.T) {
mockClobKeeper := mocks.PrepareClobKeeper{}
mockClobKeeper.On("GetOperations", mock.Anything, mock.Anything).Return(tc.keeperResp)

resp := prepare.GetProposedOperationsTx(ctx, mockTxConfig, &mockClobKeeper)
resp, err := prepare.GetProposedOperationsTx(ctx, mockTxConfig, &mockClobKeeper)
if tc.expectedErr != nil {
require.ErrorContains(t, resp.Err, tc.expectedErr.Error())
require.Equal(t, err, tc.expectedErr)
} else {
require.NoError(t, resp.Err)
require.NoError(t, err)
}
require.Equal(t, tc.expectedTx, resp.Tx)
})
Expand Down Expand Up @@ -733,7 +728,7 @@ func TestEncodeMsgsIntoTxBytes(t *testing.T) {
tx, err := prepare.EncodeMsgsIntoTxBytes(mockTxConfig, &clobtypes.MsgProposedOperations{})

if tc.expectedErr != nil {
require.ErrorContains(t, err, tc.expectedErr.Error())
require.Equal(t, err, tc.expectedErr)
} else {
require.NoError(t, err)
}
Expand Down

0 comments on commit 0f1bc2f

Please sign in to comment.