diff --git a/pkg/p2p/protobuf/protobuf.go b/pkg/p2p/protobuf/protobuf.go new file mode 100644 index 00000000..e5b69d20 --- /dev/null +++ b/pkg/p2p/protobuf/protobuf.go @@ -0,0 +1,70 @@ +package protobuf + +import ( + "context" + "fmt" + + "github.com/libp2p/go-libp2p/p2p/host/autonat/pb" + "github.com/primevprotocol/mev-commit/pkg/p2p" + "google.golang.org/protobuf/proto" +) + +type Encoder interface { + ReadMsg(context.Context, *pb.Message) error + WriteMsg(context.Context, *pb.Message) error +} + +type protobuf struct { + p2p.Stream +} + +func NewReaderWriter(s p2p.Stream) Encoder { + return &protobuf{s} +} + +func (p *protobuf) ReadMsg(ctx context.Context, msg *pb.Message) error { + type result struct { + msgBuf []byte + err error + } + + resultC := make(chan result, 1) + go func() { + msgBuf, err := p.Stream.ReadMsg() + resultC <- result{msgBuf: msgBuf, err: err} + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case res := <-resultC: + if res.err != nil { + return fmt.Errorf("failed to read msg: %w", res.err) + } + + if err := proto.Unmarshal(res.msgBuf, msg); err != nil { + return fmt.Errorf("failed to unmarshal message: %w", err) + } + + return nil + } +} + +func (p *protobuf) WriteMsg(ctx context.Context, msg *pb.Message) error { + msgBuf, err := proto.Marshal(msg) + if err != nil { + return fmt.Errorf("failed marshaling message: %w", err) + } + + errC := make(chan error, 1) + go func() { + errC <- p.Stream.WriteMsg(msgBuf) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errC: + return err + } +} diff --git a/pkg/p2p/protobuf/protobuf_test.go b/pkg/p2p/protobuf/protobuf_test.go new file mode 100644 index 00000000..20209a5d --- /dev/null +++ b/pkg/p2p/protobuf/protobuf_test.go @@ -0,0 +1,61 @@ +package protobuf_test + +import ( + "bytes" + "context" + "testing" + + "github.com/libp2p/go-libp2p/p2p/host/autonat/pb" + "github.com/primevprotocol/mev-commit/pkg/p2p/protobuf" + p2ptest "github.com/primevprotocol/mev-commit/pkg/p2p/testing" + "google.golang.org/protobuf/proto" +) + +func TestProtobufEncodingDecoding(t *testing.T) { + t.Parallel() + + t.Run("ok", func(t *testing.T) { + out, in := p2ptest.NewDuplexStream() + + test := &pb.Message{ + Type: pb.Message_DIAL.Enum(), + Dial: &pb.Message_Dial{ + Peer: &pb.Message_PeerInfo{ + Id: []byte("16Uiu2HAmK8EQ9axsSaE9hqjdHX7Hq5Jbeo2tmuNcLHwyQLWKjSYw"), + Addrs: [][]byte{ + []byte("0x9Bbc6Bef724d483C8f834C03fC2D3FE115D47ABF"), + []byte("0x903e2Abdc0fF09aBCB4C23CD8Ef1e267dfD32c2C"), + []byte("0xdCFA8524A3A266A388A4884cB6448463ae19D025"), + }, + }, + }, + } + + reader := protobuf.NewReaderWriter(in) + writer := protobuf.NewReaderWriter(out) + + if err := writer.WriteMsg(context.Background(), test); err != nil { + t.Fatal(err) + } + + var res pb.Message + err := reader.ReadMsg(context.Background(), &res) + if err != nil { + t.Fatal(err) + } + + testBytes, err := proto.Marshal(test) + if err != nil { + t.Fatal(err) + } + + resBytes, err := proto.Marshal(&res) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(testBytes, resBytes) { + t.Fatalf("expected %v, got %v", testBytes, resBytes) + } + }) +}