From 006f5ab9242c953429353cccf76a714aff320117 Mon Sep 17 00:00:00 2001 From: alok Date: Tue, 3 Oct 2023 03:02:11 +0530 Subject: [PATCH] fix: two-way handshake --- .../libp2p/internal/handshake/handshake.go | 88 +++++++++++++------ 1 file changed, 61 insertions(+), 27 deletions(-) diff --git a/pkg/p2p/libp2p/internal/handshake/handshake.go b/pkg/p2p/libp2p/internal/handshake/handshake.go index c882d555..dd956ea7 100644 --- a/pkg/p2p/libp2p/internal/handshake/handshake.go +++ b/pkg/p2p/libp2p/internal/handshake/handshake.go @@ -73,10 +73,9 @@ type HandshakeReq struct { type HandshakeResp struct { ObservedAddress common.Address PeerType string - Ack *HandshakeReq } -func (h *Service) verifySignature( +func (h *Service) verifyReq( req *HandshakeReq, peerID core.PeerID, ) (common.Address, error) { @@ -100,6 +99,17 @@ func (h *Service) verifySignature( return common.Address{}, errors.New("observed address mismatch") } + if req.PeerType == "builder" { + stake, err := h.register.GetStake(ethAddress) + if err != nil { + return common.Address{}, err + } + + if stake.Cmp(h.minimumStake) < 0 { + return common.Address{}, errors.New("stake insufficient") + } + } + return ethAddress, nil } @@ -113,6 +123,18 @@ func (h *Service) createSignature() ([]byte, error) { return sig, nil } +func (h *Service) verifyResp(resp *HandshakeResp) error { + if !bytes.Equal(resp.ObservedAddress.Bytes(), h.ethAddress.Bytes()) { + return errors.New("observed address mismatch") + } + + if resp.PeerType != h.peerType.String() { + return errors.New("peer type mismatch") + } + + return nil +} + func (h *Service) Handle( ctx context.Context, stream p2p.Stream, @@ -125,22 +147,11 @@ func (h *Service) Handle( return p2p.Peer{}, err } - ethAddress, err := h.verifySignature(req, peerID) + ethAddress, err := h.verifyReq(req, peerID) if err != nil { return p2p.Peer{}, err } - if req.PeerType == "builder" { - stake, err := h.register.GetStake(ethAddress) - if err != nil { - return p2p.Peer{}, err - } - - if stake.Cmp(h.minimumStake) < 0 { - return p2p.Peer{}, errors.New("stake insufficient") - } - } - sig, err := h.createSignature() if err != nil { return p2p.Peer{}, err @@ -149,17 +160,33 @@ func (h *Service) Handle( resp := &HandshakeResp{ ObservedAddress: ethAddress, PeerType: req.PeerType, - Ack: &HandshakeReq{ - PeerType: h.peerType.String(), - Token: h.passcode, - Sig: sig, - }, } if err := w.WriteMsg(ctx, resp); err != nil { return p2p.Peer{}, err } + ar, aw := msgpack.NewReaderWriter[HandshakeResp, HandshakeReq](stream) + + err = aw.WriteMsg(ctx, &HandshakeReq{ + PeerType: h.peerType.String(), + Token: h.passcode, + Sig: sig, + }, + ) + if err != nil { + return p2p.Peer{}, err + } + + ack, err := ar.ReadMsg(ctx) + if err != nil { + return p2p.Peer{}, err + } + + if err := h.verifyResp(ack); err != nil { + return p2p.Peer{}, err + } + return p2p.Peer{ EthAddress: ethAddress, Type: p2p.FromString(req.PeerType), @@ -194,25 +221,32 @@ func (h *Service) Handshake( return p2p.Peer{}, err } - if !bytes.Equal(resp.ObservedAddress.Bytes(), h.ethAddress.Bytes()) { - return p2p.Peer{}, errors.New("observed address mismatch") + if err := h.verifyResp(resp); err != nil { + return p2p.Peer{}, err } - if resp.PeerType != h.peerType.String() { - return p2p.Peer{}, errors.New("peer type mismatch") + ar, aw := msgpack.NewReaderWriter[HandshakeReq, HandshakeResp](stream) + + ack, err := ar.ReadMsg(ctx) + if err != nil { + return p2p.Peer{}, err } - if resp.Ack == nil { - return p2p.Peer{}, errors.New("ack not received") + ethAddress, err := h.verifyReq(ack, peerID) + if err != nil { + return p2p.Peer{}, err } - ethAddress, err := h.verifySignature(resp.Ack, peerID) + err = aw.WriteMsg(ctx, &HandshakeResp{ + ObservedAddress: ethAddress, + PeerType: ack.PeerType, + }) if err != nil { return p2p.Peer{}, err } return p2p.Peer{ EthAddress: ethAddress, - Type: p2p.FromString(resp.Ack.PeerType), + Type: p2p.FromString(ack.PeerType), }, nil }