Skip to content

Commit

Permalink
rpc : refactor backend
Browse files Browse the repository at this point in the history
Use structs for RPC request/response messages
  • Loading branch information
rgerganov committed Oct 17, 2024
1 parent becfd38 commit 4631edc
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions ggml/src/ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct socket_t {
};

// ggml_tensor is serialized into rpc_tensor
#pragma pack(push, 1)
#pragma pack(1)
struct rpc_tensor {
uint64_t id;
uint32_t type;
Expand Down Expand Up @@ -96,6 +96,17 @@ enum rpc_cmd {
RPC_CMD_COUNT,
};

#pragma pack(1)
struct request_alloc_buffer {
uint64_t size;
};

#pragma pack(1)
struct response_alloc_buffer {
uint64_t remote_ptr;
uint64_t remote_size;
};

// RPC data structures

static ggml_guid_t ggml_backend_rpc_guid() {
Expand Down Expand Up @@ -252,30 +263,31 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int

// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
uint8_t cmd_byte = cmd;
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
return false;
}
uint64_t input_size = input.size();
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
return false;
}
if (!send_data(sock->fd, input.data(), input.size())) {
if (!send_data(sock->fd, input, input_size)) {
return false;
}
uint64_t output_size;
if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
// even if we do, we can skip sending output_size from the server for commands with known output size
uint64_t out_size;
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
return false;
}
if (output_size == 0) {
output.clear();
return true;
}
output.resize(output_size);
if (!recv_data(sock->fd, output.data(), output_size)) {
if (out_size != output_size) {
return false;
}
if (output_size > 0) {
if (!recv_data(sock->fd, output, output_size)) {
return false;
}
}
return true;
}

Expand Down Expand Up @@ -484,25 +496,15 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t

static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
// input serialization format: | size (8 bytes) |
int input_size = sizeof(uint64_t);
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
request_alloc_buffer request = {size};
response_alloc_buffer response;
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
uint64_t remote_ptr;
memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
size_t remote_size;
memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
if (remote_ptr != 0) {
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
if (response.remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
remote_size);
new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
response.remote_size);
return buffer;
} else {
return nullptr;
Expand Down

0 comments on commit 4631edc

Please sign in to comment.