Skip to content

Commit

Permalink
refactor: shared query client
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanchriswhite committed Dec 11, 2024
1 parent c4685f7 commit a19f686
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 48 deletions.
4 changes: 2 additions & 2 deletions pkg/client/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ type SessionQueryClient interface {
// SharedQueryClient defines an interface that enables the querying of the
// on-chain shared module params.
type SharedQueryClient interface {
// GetParams queries the chain for the current shared module parameters.
GetParams(ctx context.Context) (*sharedtypes.Params, error)
ParamsQuerier[*sharedtypes.Params]

// GetSessionGracePeriodEndHeight returns the block height at which the grace period
// for the session that includes queryHeight elapses.
// The grace period is the number of blocks after the session ends during which relays
Expand Down
66 changes: 40 additions & 26 deletions pkg/client/query/sharedquerier.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,54 @@ import (
var _ client.SharedQueryClient = (*sharedQuerier)(nil)

// sharedQuerier is a wrapper around the sharedtypes.QueryClient that enables the
// querying of on-chain shared information through a single exposed method
// which returns an sharedtypes.Session struct
// querying of on-chain shared information
type sharedQuerier struct {
client.ParamsQuerier[*sharedtypes.Params]

clientConn grpc.ClientConn
sharedQuerier sharedtypes.QueryClient
blockQuerier client.BlockQueryClient
}

// NewSharedQuerier returns a new instance of a client.SharedQueryClient by
// injecting the dependecies provided by the depinject.Config.
// injecting the dependencies provided by the depinject.Config.
//
// Required dependencies:
// - clientCtx (grpc.ClientConn)
// - client.BlockQueryClient
func NewSharedQuerier(deps depinject.Config) (client.SharedQueryClient, error) {
querier := &sharedQuerier{}
func NewSharedQuerier(
deps depinject.Config,
paramsQuerierOpts ...ParamsQuerierOptionFn,
) (client.SharedQueryClient, error) {
paramsQuerierCfg := DefaultParamsQuerierConfig()
for _, opt := range paramsQuerierOpts {
opt(paramsQuerierCfg)
}

paramsQuerier, err := NewCachedParamsQuerier[*sharedtypes.Params, sharedtypes.SharedQueryClient](
deps, sharedtypes.NewSharedQueryClient,
WithModuleInfo[*sharedtypes.Params](sharedtypes.ModuleName, sharedtypes.ErrSharedParamInvalid),
WithParamsCacheOptions(paramsQuerierCfg.CacheOpts...),
)
if err != nil {
return nil, err
}

if err := depinject.Inject(
sq := &sharedQuerier{
ParamsQuerier: paramsQuerier,
}

if err = depinject.Inject(
deps,
&querier.clientConn,
&querier.blockQuerier,
&sq.clientConn,
&sq.blockQuerier,
); err != nil {
return nil, err
}

querier.sharedQuerier = sharedtypes.NewQueryClient(querier.clientConn)
sq.sharedQuerier = sharedtypes.NewQueryClient(sq.clientConn)

return querier, nil
}

// GetParams queries & returns the shared module on-chain parameters.
//
// TODO_TECHDEBT(#543): We don't really want to have to query the params for every method call.
// Once `ModuleParamsClient` is implemented, use its replay observable's `#Last()` method
// to get the most recently (asynchronously) observed (and cached) value.
func (sq *sharedQuerier) GetParams(ctx context.Context) (*sharedtypes.Params, error) {
req := &sharedtypes.QueryParamsRequest{}
res, err := sq.sharedQuerier.Params(ctx, req)
if err != nil {
return nil, ErrQuerySessionParams.Wrapf("[%v]", err)
}
return &res.Params, nil
return sq, nil
}

// GetClaimWindowOpenHeight returns the block height at which the claim window of
Expand Down Expand Up @@ -118,7 +124,11 @@ func (sq *sharedQuerier) GetSessionGracePeriodEndHeight(
// to get the most recently (asynchronously) observed (and cached) value.
// TODO_MAINNET(@bryanchriswhite, #543): We also don't really want to use the current value of the params.
// Instead, we should be using the value that the params had for the session which includes queryHeight.
func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(ctx context.Context, queryHeight int64, supplierOperatorAddr string) (int64, error) {
func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(
ctx context.Context,
queryHeight int64,
supplierOperatorAddr string,
) (int64, error) {
sharedParams, err := sq.GetParams(ctx)
if err != nil {
return 0, err
Expand Down Expand Up @@ -151,7 +161,11 @@ func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(ctx context.Contex
// to get the most recently (asynchronously) observed (and cached) value.
// TODO_MAINNET(@bryanchriswhite, #543): We also don't really want to use the current value of the params.
// Instead, we should be using the value that the params had for the session which includes queryHeight.
func (sq *sharedQuerier) GetEarliestSupplierProofCommitHeight(ctx context.Context, queryHeight int64, supplierOperatorAddr string) (int64, error) {
func (sq *sharedQuerier) GetEarliestSupplierProofCommitHeight(
ctx context.Context,
queryHeight int64,
supplierOperatorAddr string,
) (int64, error) {
sharedParams, err := sq.GetParams(ctx)
if err != nil {
return 0, err
Expand Down
122 changes: 122 additions & 0 deletions pkg/client/query/sharedquerier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package query_test

import (
"context"
"testing"
"time"

"cosmossdk.io/depinject"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"

"github.com/pokt-network/poktroll/pkg/client"
"github.com/pokt-network/poktroll/pkg/client/query"

Check failure on line 15 in pkg/client/query/sharedquerier_test.go

View workflow job for this annotation

GitHub Actions / go-test

could not import github.com/pokt-network/poktroll/pkg/client/query (-: # github.com/pokt-network/poktroll/pkg/client/query
"github.com/pokt-network/poktroll/pkg/client/query/cache"
_ "github.com/pokt-network/poktroll/pkg/polylog/polyzero"
"github.com/pokt-network/poktroll/testutil/mockclient"
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
)

type SharedQuerierTestSuite struct {
suite.Suite
ctrl *gomock.Controller
ctx context.Context
querier client.SharedQueryClient
mockConn *mockclient.MockClientConn
mockBlock *mockclient.MockCometRPC
TTL time.Duration
}

func TestSharedQuerierSuite(t *testing.T) {
suite.Run(t, new(SharedQuerierTestSuite))
}

func (s *SharedQuerierTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.ctx = context.Background()
s.mockConn = mockclient.NewMockClientConn(s.ctrl)
s.mockBlock = mockclient.NewMockCometRPC(s.ctrl)
s.TTL = 200 * time.Millisecond

deps := depinject.Supply(s.mockConn, s.mockBlock)

// Create querier with test-specific cache settings
querier, err := query.NewSharedQuerier(deps,
query.WithCacheOptions(
cache.WithTTL(s.TTL),
cache.WithHistoricalMode(100),
),
)
require.NoError(s.T(), err)
require.NotNil(s.T(), querier)

s.querier = querier
}

func (s *SharedQuerierTestSuite) TearDownTest() {
s.ctrl.Finish()
}

func (s *SharedQuerierTestSuite) TestRetrievesAndCachesParamsValues() {
multiplier := uint64(1000)

// First query - params with multiplier 1000
s.expectMockConnToReturnParamsWithMultiplierOnce(multiplier)

// Initial query should fetch from chain.
params1, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(multiplier, params1.ComputeUnitsToTokensMultiplier)

// Second query - should use cache, no mock expectation needed, this is
// asserted here due to the mock expectation calling Times(1).
params2, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(multiplier, params2.ComputeUnitsToTokensMultiplier)

// Third query after 90% of the TTL - should still use cache.
time.Sleep(time.Duration(float64(s.TTL) * .9))
params3, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(multiplier, params3.ComputeUnitsToTokensMultiplier)
}

func (s *SharedQuerierTestSuite) TestHandlesCacheExpiration() {
// First query
s.expectMockConnToReturnParamsWithMultiplierOnce(2000)

params1, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(uint64(2000), params1.ComputeUnitsToTokensMultiplier)

// Wait for cache to expire
time.Sleep(300 * time.Millisecond)

// Next query should hit the chain again
s.expectMockConnToReturnParamsWithMultiplierOnce(3000)

params2, err := s.querier.GetParams(s.ctx)
s.NoError(err)
s.Equal(uint64(3000), params2.ComputeUnitsToTokensMultiplier)
}

func (s *SharedQuerierTestSuite) expectMockConnToReturnParamsWithMultiplierOnce(multiplier uint64) {
s.mockConn.EXPECT().
Invoke(
gomock.Any(),
"/poktroll.shared.Query/Params",
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
DoAndReturn(func(_ context.Context, _ string, _, reply any, _ ...grpc.CallOption) error {
resp := reply.(*sharedtypes.QueryParamsResponse)
params := sharedtypes.DefaultParams()
params.ComputeUnitsToTokensMultiplier = multiplier

resp.Params = params
return nil
}).Times(1)
}
58 changes: 38 additions & 20 deletions x/proof/types/shared_query_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/pokt-network/poktroll/pkg/client"
"github.com/pokt-network/poktroll/x/shared/keeper"
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
)

Expand All @@ -13,6 +14,8 @@ var _ client.SharedQueryClient = (*SharedKeeperQueryClient)(nil)
// It does not rely on the QueryClient, and therefore does not make any
// network requests as in the off-chain implementation.
type SharedKeeperQueryClient struct {
*keeper.KeeperParamsQuerier[sharedtypes.Params, SharedKeeper]

sharedKeeper SharedKeeper
sessionKeeper SessionKeeper
}
Expand All @@ -23,20 +26,15 @@ func NewSharedKeeperQueryClient(
sharedKeeper SharedKeeper,
sessionKeeper SessionKeeper,
) client.SharedQueryClient {
keeperParamsQuerier := keeper.NewKeeperParamsQuerier[sharedtypes.Params](sharedKeeper)

return &SharedKeeperQueryClient{
sharedKeeper: sharedKeeper,
sessionKeeper: sessionKeeper,
KeeperParamsQuerier: keeperParamsQuerier,
sharedKeeper: sharedKeeper,
sessionKeeper: sessionKeeper,
}
}

// GetParams queries & returns the shared module on-chain parameters.
func (sqc *SharedKeeperQueryClient) GetParams(
ctx context.Context,
) (params *sharedtypes.Params, err error) {
sharedParams := sqc.sharedKeeper.GetParams(ctx)
return &sharedParams, nil
}

// GetSessionGracePeriodEndHeight returns the block height at which the grace period
// for the session which includes queryHeight elapses.
// The grace period is the number of blocks after the session ends during which relays
Expand All @@ -48,8 +46,12 @@ func (sqc *SharedKeeperQueryClient) GetSessionGracePeriodEndHeight(
ctx context.Context,
queryHeight int64,
) (int64, error) {
sharedParams := sqc.sharedKeeper.GetParams(ctx)
return sharedtypes.GetSessionGracePeriodEndHeight(&sharedParams, queryHeight), nil
sharedParams, err := sqc.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}

return sharedtypes.GetSessionGracePeriodEndHeight(sharedParams, queryHeight), nil
}

// GetClaimWindowOpenHeight returns the block height at which the claim window of
Expand All @@ -61,8 +63,12 @@ func (sqc *SharedKeeperQueryClient) GetClaimWindowOpenHeight(
ctx context.Context,
queryHeight int64,
) (int64, error) {
sharedParams := sqc.sharedKeeper.GetParams(ctx)
return sharedtypes.GetClaimWindowOpenHeight(&sharedParams, queryHeight), nil
sharedParams, err := sqc.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}

return sharedtypes.GetClaimWindowOpenHeight(sharedParams, queryHeight), nil
}

// GetProofWindowOpenHeight returns the block height at which the proof window of
Expand All @@ -74,8 +80,12 @@ func (sqc *SharedKeeperQueryClient) GetProofWindowOpenHeight(
ctx context.Context,
queryHeight int64,
) (int64, error) {
sharedParams := sqc.sharedKeeper.GetParams(ctx)
return sharedtypes.GetProofWindowOpenHeight(&sharedParams, queryHeight), nil
sharedParams, err := sqc.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}

return sharedtypes.GetProofWindowOpenHeight(sharedParams, queryHeight), nil
}

// GetEarliestSupplierClaimCommitHeight returns the earliest block height at which a claim
Expand Down Expand Up @@ -109,8 +119,12 @@ func (sqc *SharedKeeperQueryClient) GetEarliestSupplierProofCommitHeight(
queryHeight int64,
supplierOperatorAddr string,
) (int64, error) {
sharedParams := sqc.sharedKeeper.GetParams(ctx)
proofWindowOpenHeight := sharedtypes.GetProofWindowOpenHeight(&sharedParams, queryHeight)
sharedParams, err := sqc.GetParamsAtHeight(ctx, queryHeight)
if err != nil {
return 0, err
}

proofWindowOpenHeight := sharedtypes.GetProofWindowOpenHeight(sharedParams, queryHeight)

// Fetch the proof window open block hash so that it can be used as part of the
// pseudo-random seed for generating the proof distribution offset.
Expand All @@ -119,7 +133,7 @@ func (sqc *SharedKeeperQueryClient) GetEarliestSupplierProofCommitHeight(

// Get the earliest proof commit height for the given supplier.
return sharedtypes.GetEarliestSupplierProofCommitHeight(
&sharedParams,
sharedParams,
queryHeight,
proofWindowOpenBlockHash,
supplierOperatorAddr,
Expand All @@ -133,6 +147,10 @@ func (sqc *SharedKeeperQueryClient) GetEarliestSupplierProofCommitHeight(
// Since this will be a non-frequent occurrence, accounting for this edge case is
// not an immediate blocker.
func (sqc *SharedKeeperQueryClient) GetComputeUnitsToTokensMultiplier(ctx context.Context) (uint64, error) {
sharedParams := sqc.sharedKeeper.GetParams(ctx)
sharedParams, err := sqc.GetParamsAtHeight(ctx, 0)
if err != nil {
return 0, err
}

return sharedParams.GetComputeUnitsToTokensMultiplier(), nil
}
Loading

0 comments on commit a19f686

Please sign in to comment.