From fc7aa9af59e69a75d072f85a353732f05c3942bb Mon Sep 17 00:00:00 2001 From: Yannick Bensacq Date: Mon, 23 Sep 2024 14:23:09 +0200 Subject: [PATCH] feat(protocol): data message --- src/network/protocol/messages/getdata.zig | 158 ++++++++++++++++++++++ src/network/protocol/messages/lib.zig | 6 + src/network/wire/lib.zig | 2 + 3 files changed, 166 insertions(+) create mode 100644 src/network/protocol/messages/getdata.zig diff --git a/src/network/protocol/messages/getdata.zig b/src/network/protocol/messages/getdata.zig new file mode 100644 index 0000000..4148fb5 --- /dev/null +++ b/src/network/protocol/messages/getdata.zig @@ -0,0 +1,158 @@ +const std = @import("std"); +const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint; + +const Sha256 = std.crypto.hash.sha2.Sha256; + +const protocol = @import("../lib.zig"); + +pub const GetdataMessage = struct { + inventory: []const InventoryItem, + + pub const InventoryItem = struct { + type: u32, + hash: [32]u8, + }; + + pub inline fn name() *const [12]u8 { + return protocol.CommandNames.GETDATA ++ [_]u8{0} ** 5; + } + + /// Returns the message checksum + /// + /// Computed as `Sha256(Sha256(self.serialize()))[0..4]` + pub fn checksum(self: GetdataMessage) [4]u8 { + var digest: [32]u8 = undefined; + var hasher = Sha256.init(.{}); + const writer = hasher.writer(); + self.serializeToWriter(writer) catch unreachable; // Sha256.write is infaible + hasher.final(&digest); + + Sha256.hash(&digest, &digest, .{}); + + return digest[0..4].*; + } + + /// Serialize the message as bytes and write them to the Writer. + /// + /// `w` should be a valid `Writer`. + pub fn serializeToWriter(self: *const GetdataMessage, w: anytype) !void { + comptime { + if (!std.meta.hasFn(@TypeOf(w), "writeInt")) @compileError("Expects writer to have fn 'writeInt'."); + if (!std.meta.hasFn(@TypeOf(w), "writeAll")) @compileError("Expects writer to have fn 'writeAll'."); + } + + const count = CompactSizeUint.new(self.inventory.len); + try count.encodeToWriter(w); + + for (self.inventory) |item| { + try w.writeInt(u32, item.type, .little); + + try w.writeAll(&item.hash); + } + } + + pub fn serialize(self: *const GetdataMessage, allocator: std.mem.Allocator) ![]u8 { + const serialized_len = self.hintSerializedLen(); + + const ret = try allocator.alloc(u8, serialized_len); + errdefer allocator.free(ret); + + try self.serializeToSlice(ret); + + return ret; + } + + /// Serialize a message as bytes and write them to the buffer. + /// + /// buffer.len must be >= than self.hintSerializedLen() + pub fn serializeToSlice(self: *const GetdataMessage, buffer: []u8) !void { + var fbs = std.io.fixedBufferStream(buffer); + const writer = fbs.writer(); + try self.serializeToWriter(writer); + } + + pub fn hintSerializedLen(self: *const GetdataMessage) usize { + var length: usize = 0; + + // Adding the length of CompactSizeUint for the count + const count = CompactSizeUint.new(self.inventory.len); + length += count.hint_encoded_len(); + + // Adding the length of each inventory item + length += self.inventory.len * (4 + 32); // Type (4 bytes) + Hash (32 bytes) + + return length; + } + + pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !GetdataMessage { + comptime { + if (!std.meta.hasFn(@TypeOf(r), "readInt")) @compileError("Expects reader to have fn 'readInt'."); + if (!std.meta.hasFn(@TypeOf(r), "readNoEof")) @compileError("Expects reader to have fn 'readNoEof'."); + } + + const compact_count = try CompactSizeUint.decodeReader(r); + const count = compact_count.value(); + + const inventory = try allocator.alloc(GetdataMessage.InventoryItem, count); + + for (inventory) |*item| { + item.type = try r.readInt(u32, .little); + try r.readNoEof(&item.hash); + } + + return GetdataMessage{ + .inventory = inventory, + }; + } + + pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !GetdataMessage { + var fbs = std.io.fixedBufferStream(bytes); + const reader = fbs.reader(); + return try GetdataMessage.deserializeReader(allocator, reader); + } + + + pub fn eql(self: *const GetdataMessage, other: *const GetdataMessage) bool { + if (self.inventory.len != other.inventory.len) return false; + + var i: usize = 0; + for (self.inventory) |item| { + if (item.type != other.inventory[i].type) return false; + if (!std.mem.eql(u8, &item.hash, &other.inventory[i].hash)) return false; + i += 1; + } + return true; + } +}; + + +// TESTS + +test "ok_full_flow_GetdataMessage" { + const allocator = std.testing.allocator; + + // With some inventory items + { + const inventory_items = [_]GetdataMessage.InventoryItem{ + .{ .type = 1, .hash = [_]u8{0xab} ** 32 }, + .{ .type = 2, .hash = [_]u8{0xcd} ** 32 }, + }; + + const gd = GetdataMessage{ + .inventory = inventory_items[0..], + }; + + // Serialize + const payload = try gd.serialize(allocator); + defer allocator.free(payload); + + // Deserialize + const deserialized_gd = try GetdataMessage.deserializeSlice(allocator, payload); + + // Test equality + try std.testing.expect(gd.eql(&deserialized_gd)); + + // Free allocated memory for deserialized inventory + defer allocator.free(deserialized_gd.inventory); + } +} \ No newline at end of file diff --git a/src/network/protocol/messages/lib.zig b/src/network/protocol/messages/lib.zig index 4e9673f..255f92c 100644 --- a/src/network/protocol/messages/lib.zig +++ b/src/network/protocol/messages/lib.zig @@ -5,6 +5,7 @@ pub const MempoolMessage = @import("mempool.zig").MempoolMessage; pub const GetaddrMessage = @import("getaddr.zig").GetaddrMessage; pub const GetblocksMessage = @import("getblocks.zig").GetblocksMessage; pub const PingMessage = @import("ping.zig").PingMessage; +pub const GetdataMessage = @import("getdata.zig").GetdataMessage; pub const MessageTypes = enum { Version, @@ -13,6 +14,7 @@ pub const MessageTypes = enum { Getaddr, Getblocks, Ping, + Getdata, }; pub const Message = union(MessageTypes) { @@ -22,6 +24,7 @@ pub const Message = union(MessageTypes) { Getaddr: GetaddrMessage, Getblocks: GetblocksMessage, Ping: PingMessage, + Getdata: GetdataMessage, pub fn deinit(self: Message, allocator: std.mem.Allocator) void { switch (self) { @@ -31,6 +34,7 @@ pub const Message = union(MessageTypes) { .Getaddr => {}, .Getblocks => |m| m.deinit(allocator), .Ping => {}, + .Getdata => {}, } } pub fn checksum(self: Message) [4]u8 { @@ -41,6 +45,7 @@ pub const Message = union(MessageTypes) { .Getaddr => |m| m.checksum(), .Getblocks => |m| m.checksum(), .Ping => |m| m.checksum(), + .Getdata => |m| m.checksum(), }; } @@ -52,6 +57,7 @@ pub const Message = union(MessageTypes) { .Getaddr => |m| m.hintSerializedLen(), .Getblocks => |m| m.hintSerializedLen(), .Ping => |m| m.hintSerializedLen(), + .Getdata => |m| m.hintSerializedLen(), }; } }; diff --git a/src/network/wire/lib.zig b/src/network/wire/lib.zig index 879650b..9e034a5 100644 --- a/src/network/wire/lib.zig +++ b/src/network/wire/lib.zig @@ -100,6 +100,8 @@ pub fn receiveMessage(allocator: std.mem.Allocator, r: anytype) !protocol.messag protocol.messages.Message{ .Getblocks = try protocol.messages.GetblocksMessage.deserializeReader(allocator, r) } else if (std.mem.eql(u8, &command, protocol.messages.PingMessage.name())) protocol.messages.Message{ .Ping = try protocol.messages.PingMessage.deserializeReader(allocator, r) } + else if (std.mem.eql(u8, &command, protocol.messages.GetdataMessage.name())) + protocol.messages.Message{ .Getdata = try protocol.messages.GetdataMessage.deserializeReader(allocator, r) } else return error.UnknownMessage; errdefer message.deinit(allocator);