Skip to content

Commit

Permalink
chore: clean up elcontracts to take context (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
shrimalmadhur authored Oct 23, 2024
1 parent 0042b1a commit 565bb44
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 49 deletions.
6 changes: 3 additions & 3 deletions chainio/clients/avsregistry/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (

type eLReader interface {
CalculateOperatorAVSRegistrationDigestHash(
opts *bind.CallOpts,
ctx context.Context,
operatorAddr gethcommon.Address,
serviceManagerAddr gethcommon.Address,
operatorToAvsRegistrationSigSalt [32]byte,
Expand Down Expand Up @@ -233,7 +233,7 @@ func (w *ChainWriter) RegisterOperatorInQuorumWithAVSRegistryCoordinator(

// params to register operator in delegation manager's operator-avs mapping
msgToSign, err := w.elReader.CalculateOperatorAVSRegistrationDigestHash(
&bind.CallOpts{},
ctx,
operatorAddr,
w.serviceManagerAddr,
operatorToAvsRegistrationSigSalt,
Expand Down Expand Up @@ -355,7 +355,7 @@ func (w *ChainWriter) RegisterOperator(

// params to register operator in delegation manager's operator-avs mapping
msgToSign, err := w.elReader.CalculateOperatorAVSRegistrationDigestHash(
&bind.CallOpts{},
ctx,
operatorAddr,
w.serviceManagerAddr,
operatorToAvsRegistrationSigSalt,
Expand Down
94 changes: 63 additions & 31 deletions chainio/clients/elcontracts/reader.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package elcontracts

import (
"context"
"errors"

"math/big"
Expand Down Expand Up @@ -112,13 +113,16 @@ func NewReaderFromConfig(
), nil
}

func (r *ChainReader) IsOperatorRegistered(opts *bind.CallOpts, operator types.Operator) (bool, error) {
func (r *ChainReader) IsOperatorRegistered(
ctx context.Context,
operator types.Operator,
) (bool, error) {
if r.delegationManager == nil {
return false, errors.New("DelegationManager contract not provided")
}

isOperator, err := r.delegationManager.IsOperator(
opts,
&bind.CallOpts{Context: ctx},
gethcommon.HexToAddress(operator.Address),
)
if err != nil {
Expand All @@ -128,13 +132,16 @@ func (r *ChainReader) IsOperatorRegistered(opts *bind.CallOpts, operator types.O
return isOperator, nil
}

func (r *ChainReader) GetOperatorDetails(opts *bind.CallOpts, operator types.Operator) (types.Operator, error) {
func (r *ChainReader) GetOperatorDetails(
ctx context.Context,
operator types.Operator,
) (types.Operator, error) {
if r.delegationManager == nil {
return types.Operator{}, errors.New("DelegationManager contract not provided")
}

operatorDetails, err := r.delegationManager.OperatorDetails(
opts,
&bind.CallOpts{Context: ctx},
gethcommon.HexToAddress(operator.Address),
)
if err != nil {
Expand All @@ -150,13 +157,14 @@ func (r *ChainReader) GetOperatorDetails(opts *bind.CallOpts, operator types.Ope

// GetStrategyAndUnderlyingToken returns the strategy contract and the underlying token address
func (r *ChainReader) GetStrategyAndUnderlyingToken(
opts *bind.CallOpts, strategyAddr gethcommon.Address,
ctx context.Context,
strategyAddr gethcommon.Address,
) (*strategy.ContractIStrategy, gethcommon.Address, error) {
contractStrategy, err := strategy.NewContractIStrategy(strategyAddr, r.ethClient)
if err != nil {
return nil, common.Address{}, utils.WrapError("Failed to fetch strategy contract", err)
}
underlyingTokenAddr, err := contractStrategy.UnderlyingToken(opts)
underlyingTokenAddr, err := contractStrategy.UnderlyingToken(&bind.CallOpts{Context: ctx})
if err != nil {
return nil, common.Address{}, utils.WrapError("Failed to fetch token contract", err)
}
Expand All @@ -166,13 +174,14 @@ func (r *ChainReader) GetStrategyAndUnderlyingToken(
// GetStrategyAndUnderlyingERC20Token returns the strategy contract, the erc20 bindings for the underlying token
// and the underlying token address
func (r *ChainReader) GetStrategyAndUnderlyingERC20Token(
opts *bind.CallOpts, strategyAddr gethcommon.Address,
ctx context.Context,
strategyAddr gethcommon.Address,
) (*strategy.ContractIStrategy, erc20.ContractIERC20Methods, gethcommon.Address, error) {
contractStrategy, err := strategy.NewContractIStrategy(strategyAddr, r.ethClient)
if err != nil {
return nil, nil, common.Address{}, utils.WrapError("Failed to fetch strategy contract", err)
}
underlyingTokenAddr, err := contractStrategy.UnderlyingToken(opts)
underlyingTokenAddr, err := contractStrategy.UnderlyingToken(&bind.CallOpts{Context: ctx})
if err != nil {
return nil, nil, common.Address{}, utils.WrapError("Failed to fetch token contract", err)
}
Expand All @@ -184,7 +193,7 @@ func (r *ChainReader) GetStrategyAndUnderlyingERC20Token(
}

func (r *ChainReader) ServiceManagerCanSlashOperatorUntilBlock(
opts *bind.CallOpts,
ctx context.Context,
operatorAddr gethcommon.Address,
serviceManagerAddr gethcommon.Address,
) (uint32, error) {
Expand All @@ -193,20 +202,23 @@ func (r *ChainReader) ServiceManagerCanSlashOperatorUntilBlock(
}

return r.slasher.ContractCanSlashOperatorUntilBlock(
opts, operatorAddr, serviceManagerAddr,
&bind.CallOpts{Context: ctx}, operatorAddr, serviceManagerAddr,
)
}

func (r *ChainReader) OperatorIsFrozen(opts *bind.CallOpts, operatorAddr gethcommon.Address) (bool, error) {
func (r *ChainReader) OperatorIsFrozen(
ctx context.Context,
operatorAddr gethcommon.Address,
) (bool, error) {
if r.slasher == nil {
return false, errors.New("slasher contract not provided")
}

return r.slasher.IsFrozen(opts, operatorAddr)
return r.slasher.IsFrozen(&bind.CallOpts{Context: ctx}, operatorAddr)
}

func (r *ChainReader) GetOperatorSharesInStrategy(
opts *bind.CallOpts,
ctx context.Context,
operatorAddr gethcommon.Address,
strategyAddr gethcommon.Address,
) (*big.Int, error) {
Expand All @@ -215,92 +227,112 @@ func (r *ChainReader) GetOperatorSharesInStrategy(
}

return r.delegationManager.OperatorShares(
opts,
&bind.CallOpts{Context: ctx},
operatorAddr,
strategyAddr,
)
}

func (r *ChainReader) CalculateDelegationApprovalDigestHash(
opts *bind.CallOpts, staker gethcommon.Address, operator gethcommon.Address,
delegationApprover gethcommon.Address, approverSalt [32]byte, expiry *big.Int,
ctx context.Context,
staker gethcommon.Address,
operator gethcommon.Address,
delegationApprover gethcommon.Address,
approverSalt [32]byte,
expiry *big.Int,
) ([32]byte, error) {
if r.delegationManager == nil {
return [32]byte{}, errors.New("DelegationManager contract not provided")
}

return r.delegationManager.CalculateDelegationApprovalDigestHash(
opts, staker, operator, delegationApprover, approverSalt, expiry,
&bind.CallOpts{Context: ctx},
staker,
operator,
delegationApprover,
approverSalt,
expiry,
)
}

func (r *ChainReader) CalculateOperatorAVSRegistrationDigestHash(
opts *bind.CallOpts, operator gethcommon.Address, avs gethcommon.Address, salt [32]byte, expiry *big.Int,
ctx context.Context,
operator gethcommon.Address,
avs gethcommon.Address,
salt [32]byte,
expiry *big.Int,
) ([32]byte, error) {
if r.avsDirectory == nil {
return [32]byte{}, errors.New("AVSDirectory contract not provided")
}

return r.avsDirectory.CalculateOperatorAVSRegistrationDigestHash(
opts, operator, avs, salt, expiry,
&bind.CallOpts{Context: ctx},
operator,
avs,
salt,
expiry,
)
}

func (r *ChainReader) GetDistributionRootsLength(opts *bind.CallOpts) (*big.Int, error) {
func (r *ChainReader) GetDistributionRootsLength(ctx context.Context) (*big.Int, error) {
if r.rewardsCoordinator == nil {
return nil, errors.New("RewardsCoordinator contract not provided")
}

return r.rewardsCoordinator.GetDistributionRootsLength(opts)
return r.rewardsCoordinator.GetDistributionRootsLength(&bind.CallOpts{Context: ctx})
}

func (r *ChainReader) CurrRewardsCalculationEndTimestamp(opts *bind.CallOpts) (uint32, error) {
func (r *ChainReader) CurrRewardsCalculationEndTimestamp(ctx context.Context) (uint32, error) {
if r.rewardsCoordinator == nil {
return 0, errors.New("RewardsCoordinator contract not provided")
}

return r.rewardsCoordinator.CurrRewardsCalculationEndTimestamp(opts)
return r.rewardsCoordinator.CurrRewardsCalculationEndTimestamp(&bind.CallOpts{Context: ctx})
}

func (r *ChainReader) GetCurrentClaimableDistributionRoot(
opts *bind.CallOpts,
ctx context.Context,
) (rewardscoordinator.IRewardsCoordinatorDistributionRoot, error) {
if r.rewardsCoordinator == nil {
return rewardscoordinator.IRewardsCoordinatorDistributionRoot{}, errors.New(
"RewardsCoordinator contract not provided",
)
}

return r.rewardsCoordinator.GetCurrentClaimableDistributionRoot(opts)
return r.rewardsCoordinator.GetCurrentClaimableDistributionRoot(&bind.CallOpts{Context: ctx})
}

func (r *ChainReader) GetRootIndexFromHash(opts *bind.CallOpts, rootHash [32]byte) (uint32, error) {
func (r *ChainReader) GetRootIndexFromHash(
ctx context.Context,
rootHash [32]byte,
) (uint32, error) {
if r.rewardsCoordinator == nil {
return 0, errors.New("RewardsCoordinator contract not provided")
}

return r.rewardsCoordinator.GetRootIndexFromHash(opts, rootHash)
return r.rewardsCoordinator.GetRootIndexFromHash(&bind.CallOpts{Context: ctx}, rootHash)
}

func (r *ChainReader) GetCumulativeClaimed(
opts *bind.CallOpts,
ctx context.Context,
earner gethcommon.Address,
token gethcommon.Address,
) (*big.Int, error) {
if r.rewardsCoordinator == nil {
return nil, errors.New("RewardsCoordinator contract not provided")
}

return r.rewardsCoordinator.CumulativeClaimed(opts, earner, token)
return r.rewardsCoordinator.CumulativeClaimed(&bind.CallOpts{Context: ctx}, earner, token)
}

func (r *ChainReader) CheckClaim(
opts *bind.CallOpts,
ctx context.Context,
claim rewardscoordinator.IRewardsCoordinatorRewardsMerkleClaim,
) (bool, error) {
if r.rewardsCoordinator == nil {
return false, errors.New("RewardsCoordinator contract not provided")
}

return r.rewardsCoordinator.CheckClaim(opts, claim)
return r.rewardsCoordinator.CheckClaim(&bind.CallOpts{Context: ctx}, claim)
}
20 changes: 11 additions & 9 deletions chainio/clients/elcontracts/reader_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package elcontracts_test

import (
"context"
"math/big"
"testing"

Expand All @@ -15,20 +16,21 @@ import (

func TestChainReader(t *testing.T) {
clients, anvilHttpEndpoint := testclients.BuildTestClients(t)
ctx := context.Background()

contractAddrs := testutils.GetContractAddressesFromContractRegistry(anvilHttpEndpoint)
operator := types.Operator{
Address: "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266",
}

t.Run("is operator registered", func(t *testing.T) {
isOperator, err := clients.ElChainReader.IsOperatorRegistered(&bind.CallOpts{}, operator)
isOperator, err := clients.ElChainReader.IsOperatorRegistered(ctx, operator)
assert.NoError(t, err)
assert.Equal(t, isOperator, true)
})

t.Run("get operator details", func(t *testing.T) {
operatorDetails, err := clients.ElChainReader.GetOperatorDetails(&bind.CallOpts{}, operator)
operatorDetails, err := clients.ElChainReader.GetOperatorDetails(ctx, operator)
assert.NoError(t, err)
assert.NotNil(t, operatorDetails)
assert.Equal(t, operator.Address, operatorDetails.Address)
Expand All @@ -37,7 +39,7 @@ func TestChainReader(t *testing.T) {
t.Run("get strategy and underlying token", func(t *testing.T) {
strategyAddr := contractAddrs.Erc20MockStrategy
strategy, underlyingTokenAddr, err := clients.ElChainReader.GetStrategyAndUnderlyingToken(
&bind.CallOpts{},
ctx,
strategyAddr,
)
assert.NoError(t, err)
Expand All @@ -55,7 +57,7 @@ func TestChainReader(t *testing.T) {
t.Run("get strategy and underlying ERC20 token", func(t *testing.T) {
strategyAddr := contractAddrs.Erc20MockStrategy
strategy, contractUnderlyingToken, underlyingTokenAddr, err := clients.ElChainReader.GetStrategyAndUnderlyingERC20Token(
&bind.CallOpts{},
ctx,
strategyAddr,
)
assert.NoError(t, err)
Expand All @@ -70,7 +72,7 @@ func TestChainReader(t *testing.T) {

t.Run("service manager can slash operator until block", func(t *testing.T) {
_, err := clients.ElChainReader.ServiceManagerCanSlashOperatorUntilBlock(
&bind.CallOpts{},
ctx,
common.HexToAddress(operator.Address),
contractAddrs.ServiceManager,
)
Expand All @@ -79,7 +81,7 @@ func TestChainReader(t *testing.T) {

t.Run("operator is frozen", func(t *testing.T) {
isFrozen, err := clients.ElChainReader.OperatorIsFrozen(
&bind.CallOpts{},
ctx,
common.HexToAddress(operator.Address),
)
assert.NoError(t, err)
Expand All @@ -88,7 +90,7 @@ func TestChainReader(t *testing.T) {

t.Run("get operator shares in strategy", func(t *testing.T) {
shares, err := clients.ElChainReader.GetOperatorSharesInStrategy(
&bind.CallOpts{},
ctx,
common.HexToAddress(operator.Address),
contractAddrs.Erc20MockStrategy,
)
Expand All @@ -102,7 +104,7 @@ func TestChainReader(t *testing.T) {
approverSalt := [32]byte{}
expiry := big.NewInt(0)
digest, err := clients.ElChainReader.CalculateDelegationApprovalDigestHash(
&bind.CallOpts{},
ctx,
staker,
common.HexToAddress(operator.Address),
delegationApprover,
Expand All @@ -118,7 +120,7 @@ func TestChainReader(t *testing.T) {
salt := [32]byte{}
expiry := big.NewInt(0)
digest, err := clients.ElChainReader.CalculateOperatorAVSRegistrationDigestHash(
&bind.CallOpts{},
ctx,
common.HexToAddress(operator.Address),
avs,
salt,
Expand Down
5 changes: 2 additions & 3 deletions chainio/clients/elcontracts/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"math/big"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
gethcommon "github.com/ethereum/go-ethereum/common"
gethtypes "github.com/ethereum/go-ethereum/core/types"

Expand All @@ -27,7 +26,7 @@ import (

type Reader interface {
GetStrategyAndUnderlyingERC20Token(
opts *bind.CallOpts, strategyAddr gethcommon.Address,
ctx context.Context, strategyAddr gethcommon.Address,
) (*strategy.ContractIStrategy, erc20.ContractIERC20Methods, gethcommon.Address, error)
}

Expand Down Expand Up @@ -276,7 +275,7 @@ func (w *ChainWriter) DepositERC20IntoStrategy(
return nil, err
}
_, underlyingTokenContract, underlyingTokenAddr, err := w.elChainReader.GetStrategyAndUnderlyingERC20Token(
&bind.CallOpts{Context: ctx},
ctx,
strategyAddr,
)
if err != nil {
Expand Down
Loading

0 comments on commit 565bb44

Please sign in to comment.