Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(protocol/message): add inventory message #161

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions src/network/protocol/messages/getdata.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ const std = @import("std");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const message = @import("./lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;
const genericSerialize = @import("lib.zig").genericSerialize;

const Sha256 = std.crypto.hash.sha2.Sha256;

const protocol = @import("../lib.zig");

pub const GetdataMessage = struct {
inventory: []const protocol.InventoryItem,
const Self = @This();

pub inline fn name() *const [12]u8 {
return protocol.CommandNames.GETDATA ++ [_]u8{0} ** 5;
Expand Down Expand Up @@ -39,14 +42,7 @@ pub const GetdataMessage = struct {
}

pub fn serialize(self: *const GetdataMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
return genericSerialize(self, allocator);
}

/// Serialize a message as bytes and write them to the buffer.
Expand Down Expand Up @@ -94,10 +90,8 @@ pub const GetdataMessage = struct {
}

/// Deserialize bytes into a `GetdataMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !GetdataMessage {
var fbs = std.io.fixedBufferStream(bytes);
const reader = fbs.reader();
return try GetdataMessage.deserializeReader(allocator, reader);
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
return genericDeserializeSlice(Self, allocator, bytes);
}


Expand Down
146 changes: 146 additions & 0 deletions src/network/protocol/messages/inv.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
const std = @import("std");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const message = @import("./lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

const Sha256 = std.crypto.hash.sha2.Sha256;

const protocol = @import("../lib.zig");

pub const InvMessage = struct {
inventory: []const protocol.InventoryItem,
const Self = @This();

pub inline fn name() *const [12]u8 {
return protocol.CommandNames.INV ++ [_]u8{0} ** 5;
}

/// Returns the message checksum
///
/// Computed as `Sha256(Sha256(self.serialize()))[0..4]`
pub fn checksum(self: *const InvMessage) [4]u8 {
return genericChecksum(self);
}

/// Free the `inventory`
pub fn deinit(self: InvMessage, allocator: std.mem.Allocator) void {
allocator.free(self.inventory);
}

/// Serialize the message as bytes and write them to the Writer.
///
/// `w` should be a valid `Writer`.
pub fn serializeToWriter(self: *const InvMessage, w: anytype) !void {
const count = CompactSizeUint.new(self.inventory.len);
try count.encodeToWriter(w);

for (self.inventory) |item| {
try item.encodeToWriter(w);
}
}

pub fn serialize(self: *const InvMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const InvMessage, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
const writer = fbs.writer();
try self.serializeToWriter(writer);
}

pub fn hintSerializedLen(self: *const InvMessage) usize {
var length: usize = 0;

// Adding the length of CompactSizeUint for the count
const count = CompactSizeUint.new(self.inventory.len);
length += count.hint_encoded_len();

// Adding the length of each inventory item
length += self.inventory.len * (4 + 32); // Type (4 bytes) + Hash (32 bytes)

return length;
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !InvMessage {

const compact_count = try CompactSizeUint.decodeReader(r);
const count = compact_count.value();
if (count == 0) {
return InvMessage{
.inventory = &[_]protocol.InventoryItem{},
};
}

const inventory = try allocator.alloc(protocol.InventoryItem, count);
errdefer allocator.free(inventory);

for (inventory) |*item| {
item.* = try protocol.InventoryItem.decodeReader(r);
}

return InvMessage{
.inventory = inventory,
};
}

/// Deserialize bytes into a `InvMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
return genericDeserializeSlice(Self, allocator, bytes);
}


pub fn eql(self: *const InvMessage, other: *const InvMessage) bool {
if (self.inventory.len != other.inventory.len) return false;

for (0..self.inventory.len) |i| {
const item_self = self.inventory[i];
const item_other = other.inventory[i];
if (!item_self.eql(&item_other)) {
return false;
}
}

return true;
}
};


// TESTS
test "ok_full_flow_inv_message" {
const allocator = std.testing.allocator;

// With some inventory items
{
const inventory_items = [_]protocol.InventoryItem{
.{ .type = 1, .hash = [_]u8{0xab} ** 32 },
.{ .type = 2, .hash = [_]u8{0xcd} ** 32 },
.{ .type = 2, .hash = [_]u8{0xef} ** 32 },
};

const gd = InvMessage{
.inventory = inventory_items[0..],
};

const payload = try gd.serialize(allocator);
defer allocator.free(payload);

const deserialized_gd = try InvMessage.deserializeSlice(allocator, payload);

try std.testing.expect(gd.eql(&deserialized_gd));

// Free allocated memory for deserialized inventory
defer allocator.free(deserialized_gd.inventory);
}
}
7 changes: 7 additions & 0 deletions src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub const SendHeadersMessage = @import("sendheaders.zig").SendHeadersMessage;
pub const FilterLoadMessage = @import("filterload.zig").FilterLoadMessage;
pub const HeadersMessage = @import("headers.zig").HeadersMessage;
pub const CmpctBlockMessage = @import("cmpctblock.zig").CmpctBlockMessage;
pub const InvMessage = @import("inv.zig").InvMessage;

pub const MessageTypes = enum {
version,
Expand All @@ -41,6 +42,7 @@ pub const MessageTypes = enum {
getdata,
headers,
cmpctblock,
inv
};


Expand All @@ -64,6 +66,7 @@ pub const Message = union(MessageTypes) {
getdata: GetdataMessage,
headers: HeadersMessage,
cmpctblock: CmpctBlockMessage,
inv: InvMessage,

pub fn name(self: Message) *const [12]u8 {
return switch (self) {
Expand All @@ -86,6 +89,7 @@ pub const Message = union(MessageTypes) {
.getdata => |m| @TypeOf(m).name(),
.headers => |m| @TypeOf(m).name(),
.cmpctblock => |m| @TypeOf(m).name(),
.inv => |m| @TypeOf(m).name(),
};
}

Expand All @@ -99,6 +103,7 @@ pub const Message = union(MessageTypes) {
.getdata => |*m| m.deinit(allocator),
.cmpctblock => |*m| m.deinit(allocator),
.headers => |*m| m.deinit(allocator),
.inv => |*m| m.deinit(allocator),
else => {}
}
}
Expand All @@ -124,6 +129,7 @@ pub const Message = union(MessageTypes) {
.getdata => |*m| m.checksum(),
.headers => |*m| m.checksum(),
.cmpctblock => |*m| m.checksum(),
.inv => |*m| m.checksum(),
};
}

Expand All @@ -148,6 +154,7 @@ pub const Message = union(MessageTypes) {
.getdata => |m| m.hintSerializedLen(),
.headers => |*m| m.hintSerializedLen(),
.cmpctblock => |*m| m.hintSerializedLen(),
.inv => |*m| m.hintSerializedLen(),
};
}
};
Expand Down
40 changes: 40 additions & 0 deletions src/network/wire/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ pub fn receiveMessage(
protocol.messages.Message{ .getdata = try protocol.messages.GetdataMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.CmpctBlockMessage.name()))
protocol.messages.Message{ .cmpctblock = try protocol.messages.CmpctBlockMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.InvMessage.name()))
protocol.messages.Message{ .inv = try protocol.messages.InvMessage.deserializeReader(allocator, r) }
else {
try r.skipBytes(payload_len, .{}); // Purge the wire
return error.UnknownMessage;
Expand Down Expand Up @@ -690,3 +692,41 @@ test "ok_send_cmpctblock_message" {
try std.testing.expect(hint_len > 0);
try std.testing.expect(hint_len == serialized.len);
}

test "ok_send_inv_message" {
const Config = @import("../../config/config.zig").Config;
const ArrayList = std.ArrayList;
const test_allocator = std.testing.allocator;
const InvMessage = protocol.messages.InvMessage;

var list: std.ArrayListAligned(u8, null) = ArrayList(u8).init(test_allocator);
defer list.deinit();

const inventory = try test_allocator.alloc(protocol.InventoryItem, 5);
defer test_allocator.free(inventory);

for (inventory) |*item| {
item.type = 1;
for (&item.hash) |*byte| {
byte.* = 0xab;
}
}

const message = InvMessage{
.inventory = inventory,
};

var received_message = try write_and_read_message(
test_allocator,
&list,
Config.BitcoinNetworkId.MAINNET,
Config.PROTOCOL_VERSION,
message,
) orelse unreachable;
defer received_message.deinit(test_allocator);

switch (received_message) {
.inv => |rm| try std.testing.expect(message.eql(&rm)),
else => unreachable,
}
}
Loading