diff --git a/src/network/protocol/messages/getdata.zig b/src/network/protocol/messages/getdata.zig index 3ef18d6..d0979ec 100644 --- a/src/network/protocol/messages/getdata.zig +++ b/src/network/protocol/messages/getdata.zig @@ -2,6 +2,8 @@ const std = @import("std"); const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint; const message = @import("./lib.zig"); const genericChecksum = @import("lib.zig").genericChecksum; +const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice; +const genericSerialize = @import("lib.zig").genericSerialize; const Sha256 = std.crypto.hash.sha2.Sha256; @@ -9,6 +11,7 @@ const protocol = @import("../lib.zig"); pub const GetdataMessage = struct { inventory: []const protocol.InventoryItem, + const Self = @This(); pub inline fn name() *const [12]u8 { return protocol.CommandNames.GETDATA ++ [_]u8{0} ** 5; @@ -39,14 +42,7 @@ pub const GetdataMessage = struct { } pub fn serialize(self: *const GetdataMessage, allocator: std.mem.Allocator) ![]u8 { - const serialized_len = self.hintSerializedLen(); - - const ret = try allocator.alloc(u8, serialized_len); - errdefer allocator.free(ret); - - try self.serializeToSlice(ret); - - return ret; + return genericSerialize(self, allocator); } /// Serialize a message as bytes and write them to the buffer. @@ -94,10 +90,8 @@ pub const GetdataMessage = struct { } /// Deserialize bytes into a `GetdataMessage` - pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !GetdataMessage { - var fbs = std.io.fixedBufferStream(bytes); - const reader = fbs.reader(); - return try GetdataMessage.deserializeReader(allocator, reader); + pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self { + return genericDeserializeSlice(Self, allocator, bytes); } diff --git a/src/network/protocol/messages/inv.zig b/src/network/protocol/messages/inv.zig new file mode 100644 index 0000000..e247586 --- /dev/null +++ b/src/network/protocol/messages/inv.zig @@ -0,0 +1,146 @@ +const std = @import("std"); +const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint; +const message = @import("./lib.zig"); +const genericChecksum = @import("lib.zig").genericChecksum; +const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice; + +const Sha256 = std.crypto.hash.sha2.Sha256; + +const protocol = @import("../lib.zig"); + +pub const InvMessage = struct { + inventory: []const protocol.InventoryItem, + const Self = @This(); + + pub inline fn name() *const [12]u8 { + return protocol.CommandNames.INV ++ [_]u8{0} ** 5; + } + + /// Returns the message checksum + /// + /// Computed as `Sha256(Sha256(self.serialize()))[0..4]` + pub fn checksum(self: *const InvMessage) [4]u8 { + return genericChecksum(self); + } + + /// Free the `inventory` + pub fn deinit(self: InvMessage, allocator: std.mem.Allocator) void { + allocator.free(self.inventory); + } + + /// Serialize the message as bytes and write them to the Writer. + /// + /// `w` should be a valid `Writer`. + pub fn serializeToWriter(self: *const InvMessage, w: anytype) !void { + const count = CompactSizeUint.new(self.inventory.len); + try count.encodeToWriter(w); + + for (self.inventory) |item| { + try item.encodeToWriter(w); + } + } + + pub fn serialize(self: *const InvMessage, allocator: std.mem.Allocator) ![]u8 { + const serialized_len = self.hintSerializedLen(); + + const ret = try allocator.alloc(u8, serialized_len); + errdefer allocator.free(ret); + + try self.serializeToSlice(ret); + + return ret; + } + + /// Serialize a message as bytes and write them to the buffer. + /// + /// buffer.len must be >= than self.hintSerializedLen() + pub fn serializeToSlice(self: *const InvMessage, buffer: []u8) !void { + var fbs = std.io.fixedBufferStream(buffer); + const writer = fbs.writer(); + try self.serializeToWriter(writer); + } + + pub fn hintSerializedLen(self: *const InvMessage) usize { + var length: usize = 0; + + // Adding the length of CompactSizeUint for the count + const count = CompactSizeUint.new(self.inventory.len); + length += count.hint_encoded_len(); + + // Adding the length of each inventory item + length += self.inventory.len * (4 + 32); // Type (4 bytes) + Hash (32 bytes) + + return length; + } + + pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !InvMessage { + + const compact_count = try CompactSizeUint.decodeReader(r); + const count = compact_count.value(); + if (count == 0) { + return InvMessage{ + .inventory = &[_]protocol.InventoryItem{}, + }; + } + + const inventory = try allocator.alloc(protocol.InventoryItem, count); + errdefer allocator.free(inventory); + + for (inventory) |*item| { + item.* = try protocol.InventoryItem.decodeReader(r); + } + + return InvMessage{ + .inventory = inventory, + }; + } + + /// Deserialize bytes into a `InvMessage` + pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self { + return genericDeserializeSlice(Self, allocator, bytes); + } + + + pub fn eql(self: *const InvMessage, other: *const InvMessage) bool { + if (self.inventory.len != other.inventory.len) return false; + + for (0..self.inventory.len) |i| { + const item_self = self.inventory[i]; + const item_other = other.inventory[i]; + if (!item_self.eql(&item_other)) { + return false; + } + } + + return true; + } +}; + + +// TESTS +test "ok_full_flow_inv_message" { + const allocator = std.testing.allocator; + + // With some inventory items + { + const inventory_items = [_]protocol.InventoryItem{ + .{ .type = 1, .hash = [_]u8{0xab} ** 32 }, + .{ .type = 2, .hash = [_]u8{0xcd} ** 32 }, + .{ .type = 2, .hash = [_]u8{0xef} ** 32 }, + }; + + const gd = InvMessage{ + .inventory = inventory_items[0..], + }; + + const payload = try gd.serialize(allocator); + defer allocator.free(payload); + + const deserialized_gd = try InvMessage.deserializeSlice(allocator, payload); + + try std.testing.expect(gd.eql(&deserialized_gd)); + + // Free allocated memory for deserialized inventory + defer allocator.free(deserialized_gd.inventory); + } +} \ No newline at end of file diff --git a/src/network/protocol/messages/lib.zig b/src/network/protocol/messages/lib.zig index 7395c1c..3d5e411 100644 --- a/src/network/protocol/messages/lib.zig +++ b/src/network/protocol/messages/lib.zig @@ -20,6 +20,7 @@ pub const SendHeadersMessage = @import("sendheaders.zig").SendHeadersMessage; pub const FilterLoadMessage = @import("filterload.zig").FilterLoadMessage; pub const HeadersMessage = @import("headers.zig").HeadersMessage; pub const CmpctBlockMessage = @import("cmpctblock.zig").CmpctBlockMessage; +pub const InvMessage = @import("inv.zig").InvMessage; pub const MessageTypes = enum { version, @@ -41,6 +42,7 @@ pub const MessageTypes = enum { getdata, headers, cmpctblock, + inv }; @@ -64,6 +66,7 @@ pub const Message = union(MessageTypes) { getdata: GetdataMessage, headers: HeadersMessage, cmpctblock: CmpctBlockMessage, + inv: InvMessage, pub fn name(self: Message) *const [12]u8 { return switch (self) { @@ -86,6 +89,7 @@ pub const Message = union(MessageTypes) { .getdata => |m| @TypeOf(m).name(), .headers => |m| @TypeOf(m).name(), .cmpctblock => |m| @TypeOf(m).name(), + .inv => |m| @TypeOf(m).name(), }; } @@ -99,6 +103,7 @@ pub const Message = union(MessageTypes) { .getdata => |*m| m.deinit(allocator), .cmpctblock => |*m| m.deinit(allocator), .headers => |*m| m.deinit(allocator), + .inv => |*m| m.deinit(allocator), else => {} } } @@ -124,6 +129,7 @@ pub const Message = union(MessageTypes) { .getdata => |*m| m.checksum(), .headers => |*m| m.checksum(), .cmpctblock => |*m| m.checksum(), + .inv => |*m| m.checksum(), }; } @@ -148,6 +154,7 @@ pub const Message = union(MessageTypes) { .getdata => |m| m.hintSerializedLen(), .headers => |*m| m.hintSerializedLen(), .cmpctblock => |*m| m.hintSerializedLen(), + .inv => |*m| m.hintSerializedLen(), }; } }; diff --git a/src/network/wire/lib.zig b/src/network/wire/lib.zig index 85163ff..c67a367 100644 --- a/src/network/wire/lib.zig +++ b/src/network/wire/lib.zig @@ -145,6 +145,8 @@ pub fn receiveMessage( protocol.messages.Message{ .getdata = try protocol.messages.GetdataMessage.deserializeReader(allocator, r) } else if (std.mem.eql(u8, &command, protocol.messages.CmpctBlockMessage.name())) protocol.messages.Message{ .cmpctblock = try protocol.messages.CmpctBlockMessage.deserializeReader(allocator, r) } + else if (std.mem.eql(u8, &command, protocol.messages.InvMessage.name())) + protocol.messages.Message{ .inv = try protocol.messages.InvMessage.deserializeReader(allocator, r) } else { try r.skipBytes(payload_len, .{}); // Purge the wire return error.UnknownMessage; @@ -690,3 +692,41 @@ test "ok_send_cmpctblock_message" { try std.testing.expect(hint_len > 0); try std.testing.expect(hint_len == serialized.len); } + +test "ok_send_inv_message" { + const Config = @import("../../config/config.zig").Config; + const ArrayList = std.ArrayList; + const test_allocator = std.testing.allocator; + const InvMessage = protocol.messages.InvMessage; + + var list: std.ArrayListAligned(u8, null) = ArrayList(u8).init(test_allocator); + defer list.deinit(); + + const inventory = try test_allocator.alloc(protocol.InventoryItem, 5); + defer test_allocator.free(inventory); + + for (inventory) |*item| { + item.type = 1; + for (&item.hash) |*byte| { + byte.* = 0xab; + } + } + + const message = InvMessage{ + .inventory = inventory, + }; + + var received_message = try write_and_read_message( + test_allocator, + &list, + Config.BitcoinNetworkId.MAINNET, + Config.PROTOCOL_VERSION, + message, + ) orelse unreachable; + defer received_message.deinit(test_allocator); + + switch (received_message) { + .inv => |rm| try std.testing.expect(message.eql(&rm)), + else => unreachable, + } +}