Skip to content

Commit

Permalink
[DEC-2102] Audit lib/collections (#582)
Browse files Browse the repository at this point in the history
* SliceToSet -> UniqueSliceToSet

* imports
  • Loading branch information
BrendanChou authored Oct 12, 2023
1 parent cfa40b1 commit 6f5d9db
Show file tree
Hide file tree
Showing 17 changed files with 92 additions and 103 deletions.
7 changes: 4 additions & 3 deletions protocol/lib/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ func GetSortedKeys[R interface {
return keys
}

// SliceToSet converts a slice to a set. Function will panic if there are duplicate values.
func SliceToSet[K comparable](values []K) map[K]struct{} {
// UniqueSliceToSet converts a slice of unique values to a set.
// The function will panic if there are duplicate values.
func UniqueSliceToSet[K comparable](values []K) map[K]struct{} {
set := make(map[K]struct{}, len(values))
for _, sliceVal := range values {
if _, exists := set[sliceVal]; exists {
panic(
fmt.Sprintf(
"SliceToSet: duplicate value: %+v",
"UniqueSliceToSet: duplicate value: %+v",
sliceVal,
),
)
Expand Down
136 changes: 64 additions & 72 deletions protocol/lib/collections_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,37 @@ import (
"testing"

"github.com/dydxprotocol/v4-chain/protocol/lib"
"github.com/dydxprotocol/v4-chain/protocol/x/clob/types"

"github.com/stretchr/testify/require"
)

func TestContainsDuplicates(t *testing.T) {
// Empty case.
require.False(t, lib.ContainsDuplicates([]types.OrderId{}))

// Unique uint32 case.
allUniqueUint32s := []uint32{1, 2, 3, 4}
require.False(t, lib.ContainsDuplicates(allUniqueUint32s))

// Duplicate uint32 case.
containsDuplicateUint32 := append(allUniqueUint32s, 3)
require.True(t, lib.ContainsDuplicates(containsDuplicateUint32))

// Unique string case.
allUniqueStrings := []string{"hello", "world", "h", "w"}
require.False(t, lib.ContainsDuplicates(allUniqueStrings))

// Duplicate string case.
containsDuplicateString := append(allUniqueStrings, "world")
require.True(t, lib.ContainsDuplicates(containsDuplicateString))
tests := map[string]struct {
input []uint32
expected bool
}{
"Nil": {
input: nil,
expected: false,
},
"Empty": {
input: []uint32{},
expected: false,
},
"True": {
input: []uint32{1, 2, 3, 4},
expected: false,
},
"False": {
input: []uint32{1, 2, 3, 4, 3},
expected: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
require.Equal(t, tc.expected, lib.ContainsDuplicates(tc.input))
})
}
}

func TestGetSortedKeys(t *testing.T) {
Expand Down Expand Up @@ -59,6 +66,44 @@ func TestGetSortedKeys(t *testing.T) {
}
}

func TestUniqueSliceToSet(t *testing.T) {
tests := map[string]struct {
input []string
expected map[string]struct{}
panicWith string
}{
"Empty": {
input: []string{},
expected: map[string]struct{}{},
},
"Basic": {
input: []string{"0", "1", "2"},
expected: map[string]struct{}{
"0": {},
"1": {},
"2": {},
},
},
"Duplicate": {
input: []string{"one", "2", "two", "one", "4"},
panicWith: "UniqueSliceToSet: duplicate value: one",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
if tc.panicWith != "" {
require.PanicsWithValue(
t,
tc.panicWith,
func() { lib.UniqueSliceToSet[string](tc.input) },
)
} else {
require.Equal(t, tc.expected, lib.UniqueSliceToSet[string](tc.input))
}
})
}
}

func TestMapSlice(t *testing.T) {
// Can increment all numbers in a slice by 1, and change type to `uint64`.
require.Equal(
Expand Down Expand Up @@ -171,59 +216,6 @@ func TestFilterSlice(t *testing.T) {
)
}

func TestSliceToSet(t *testing.T) {
slice := make([]int, 0)
for i := 0; i < 3; i++ {
slice = append(slice, i)
}
set := lib.SliceToSet(slice)
require.Equal(
t,
map[int]struct{}{
0: {},
1: {},
2: {},
},
set,
)
stringSlice := []string{
"one",
"two",
}
stringSet := lib.SliceToSet(stringSlice)
require.Equal(
t,
map[string]struct{}{
"one": {},
"two": {},
},
stringSet,
)

emptySlice := []types.OrderId{}
emptySet := lib.SliceToSet(emptySlice)
require.Equal(
t,
map[types.OrderId]struct{}{},
emptySet,
)
}

func TestSliceToSet_PanicOnDuplicate(t *testing.T) {
stringSlice := []string{
"one",
"two",
"one",
}
require.PanicsWithValue(
t,
"SliceToSet: duplicate value: one",
func() {
lib.SliceToSet(stringSlice)
},
)
}

func TestMergeAllMapsWithDistinctKeys(t *testing.T) {
tests := map[string]struct {
inputMaps []map[string]string
Expand Down
2 changes: 1 addition & 1 deletion protocol/x/blocktime/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewKeeper(
return &Keeper{
cdc: cdc,
storeKey: storeKey,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
2 changes: 1 addition & 1 deletion protocol/x/bridge/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func NewKeeper(
bridgeEventManager: bridgeEventManager,
bankKeeper: bankKeeper,
delayMsgKeeper: delayMsgKeeper,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
4 changes: 2 additions & 2 deletions protocol/x/clob/abci.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ func EndBlocker(
keeper.AddUntriggeredConditionalOrders(
ctx,
processProposerMatchesEvents.PlacedConditionalOrderIds,
lib.SliceToSet(processProposerMatchesEvents.GetPlacedStatefulCancellationOrderIds()),
lib.SliceToSet(expiredStatefulOrderIds),
lib.UniqueSliceToSet(processProposerMatchesEvents.GetPlacedStatefulCancellationOrderIds()),
lib.UniqueSliceToSet(expiredStatefulOrderIds),
)

// Poll out all triggered conditional orders from `UntriggeredConditionalOrders` and update state.
Expand Down
2 changes: 1 addition & 1 deletion protocol/x/clob/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func NewKeeper(
storeKey: storeKey,
memKey: memKey,
transientStoreKey: liquidationsStoreKey,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
MemClob: memClob,
UntriggeredConditionalOrders: make(map[types.ClobPairId]*UntriggeredConditionalOrders),
PerpetualIdToClobPairId: make(map[uint32][]types.ClobPairId),
Expand Down
2 changes: 1 addition & 1 deletion protocol/x/clob/keeper/msg_server_cancel_orders.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (k msgServer) CancelOrder(
// TODO(CLOB-778): Prevent invalid MsgCancelOrder messages from being included in the block.
if errors.Is(err, types.ErrStatefulOrderDoesNotExist) {
processProposerMatchesEvents := k.Keeper.GetProcessProposerMatchesEvents(ctx)
removedOrderIds := lib.SliceToSet(processProposerMatchesEvents.RemovedStatefulOrderIds)
removedOrderIds := lib.UniqueSliceToSet(processProposerMatchesEvents.RemovedStatefulOrderIds)
if _, found := removedOrderIds[msg.GetOrderId()]; found {
telemetry.IncrCounterWithLabels(
[]string{
Expand Down
4 changes: 2 additions & 2 deletions protocol/x/clob/keeper/msg_server_place_order.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ func (k msgServer) PlaceOrder(goCtx context.Context, msg *types.MsgPlaceOrder) (

// 2. Return an error if an associated cancellation or removal already exists in the current block.
processProposerMatchesEvents := k.Keeper.GetProcessProposerMatchesEvents(ctx)
cancelledOrderIds := lib.SliceToSet(processProposerMatchesEvents.PlacedStatefulCancellationOrderIds)
cancelledOrderIds := lib.UniqueSliceToSet(processProposerMatchesEvents.PlacedStatefulCancellationOrderIds)
if _, found := cancelledOrderIds[order.GetOrderId()]; found {
return nil, errorsmod.Wrapf(
types.ErrStatefulOrderPreviouslyCancelled,
"PlaceOrder: order (%+v)",
order,
)
}
removedOrderIds := lib.SliceToSet(processProposerMatchesEvents.RemovedStatefulOrderIds)
removedOrderIds := lib.UniqueSliceToSet(processProposerMatchesEvents.RemovedStatefulOrderIds)
if _, found := removedOrderIds[order.GetOrderId()]; found {
return nil, errorsmod.Wrapf(
types.ErrStatefulOrderPreviouslyRemoved,
Expand Down
11 changes: 5 additions & 6 deletions protocol/x/clob/keeper/untriggered_conditional_orders.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ package keeper

import (
"fmt"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
"math/big"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/lib"
"github.com/dydxprotocol/v4-chain/protocol/x/clob/types"

indexerevents "github.com/dydxprotocol/v4-chain/protocol/indexer/events"
"github.com/dydxprotocol/v4-chain/protocol/indexer/indexer_manager"
"github.com/dydxprotocol/v4-chain/protocol/lib"
"github.com/dydxprotocol/v4-chain/protocol/x/clob/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
)

// UntriggeredConditionalOrders is an in-memory struct stored on the clob Keeper.
Expand Down Expand Up @@ -130,7 +129,7 @@ func (k Keeper) PruneUntriggeredConditionalOrders(
cancelledStatefulOrderIds []types.OrderId,
) {
// Merge lists of order ids.
orderIdsToPrune := lib.SliceToSet(expiredStatefulOrderIds)
orderIdsToPrune := lib.UniqueSliceToSet(expiredStatefulOrderIds)
for _, orderId := range cancelledStatefulOrderIds {
if _, exists := orderIdsToPrune[orderId]; exists {
panic(
Expand Down Expand Up @@ -198,7 +197,7 @@ func (untriggeredOrders *UntriggeredConditionalOrders) RemoveUntriggeredConditio
}
}

orderIdsToRemoveSet := lib.SliceToSet(orderIdsToRemove)
orderIdsToRemoveSet := lib.UniqueSliceToSet(orderIdsToRemove)

newOrdersToTriggerWhenOraclePriceLTETriggerPrice := make([]types.Order, 0)
for _, order := range untriggeredOrders.OrdersToTriggerWhenOraclePriceLTETriggerPrice {
Expand Down
2 changes: 1 addition & 1 deletion protocol/x/delaymsg/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func NewKeeper(
return &Keeper{
cdc: cdc,
storeKey: storeKey,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
router: router,
}
}
Expand Down
2 changes: 1 addition & 1 deletion protocol/x/feetiers/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func NewKeeper(
cdc: cdc,
statsKeeper: statsKeeper,
storeKey: storeKey,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
2 changes: 1 addition & 1 deletion protocol/x/perpetuals/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewKeeper(
pricesKeeper: pricesKeeper,
epochsKeeper: epochsKeeper,
indexerEventManager: indexerEventsManager,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
2 changes: 1 addition & 1 deletion protocol/x/prices/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NewKeeper(
timeProvider: timeProvider,
indexerEventManager: indexerEventManager,
marketToCreatedAt: map[uint32]time.Time{},
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
11 changes: 4 additions & 7 deletions protocol/x/rewards/keeper/keeper.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
package keeper

import (
errorsmod "cosmossdk.io/errors"
"fmt"
"math/big"
"time"

"github.com/dydxprotocol/v4-chain/protocol/daemons/pricefeed/client/constants"

sdkmath "cosmossdk.io/math"

errorsmod "cosmossdk.io/errors"
sdklog "cosmossdk.io/log"
sdkmath "cosmossdk.io/math"
"github.com/cometbft/cometbft/libs/log"
"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/store/prefix"
storetypes "github.com/cosmos/cosmos-sdk/store/types"
"github.com/cosmos/cosmos-sdk/telemetry"
sdk "github.com/cosmos/cosmos-sdk/types"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"

"github.com/dydxprotocol/v4-chain/protocol/daemons/pricefeed/client/constants"
"github.com/dydxprotocol/v4-chain/protocol/dtypes"
"github.com/dydxprotocol/v4-chain/protocol/lib"
"github.com/dydxprotocol/v4-chain/protocol/lib/metrics"
Expand Down Expand Up @@ -65,7 +62,7 @@ func NewKeeper(
bankKeeper: bankKeeper,
feeTiersKeeper: feeTiersKeeper,
pricesKeeper: pricesKeeper,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
2 changes: 1 addition & 1 deletion protocol/x/sending/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewKeeper(
bankKeeper: bankKeeper,
subaccountsKeeper: subaccountsKeeper,
indexerEventManager: indexerEventManager,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
2 changes: 1 addition & 1 deletion protocol/x/stats/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func NewKeeper(
epochsKeeper: epochsKeeper,
storeKey: storeKey,
transientStoreKey: transientStoreKey,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down
2 changes: 1 addition & 1 deletion protocol/x/vest/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func NewKeeper(
storeKey: storeKey,
bankKeeper: bankKeeper,
blockTimeKeeper: blockTimeKeeper,
authorities: lib.SliceToSet(authorities),
authorities: lib.UniqueSliceToSet(authorities),
}
}

Expand Down

0 comments on commit 6f5d9db

Please sign in to comment.