diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index 22d9524b8d764..5f979c31a107b 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -5,8 +5,11 @@ #include #include #include +#include #include #include +#include +#include #include #include #ifdef _WIN32 @@ -17,6 +20,7 @@ # include # include #else +# include # include # include # include @@ -89,6 +93,7 @@ enum rpc_cmd { COPY_TENSOR, GRAPH_COMPUTE, GET_DEVICE_MEMORY, + FREE_ALL_BUFFERS, }; // RPC data structures @@ -736,6 +741,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint // RPC server-side implementation +template +class message_queue { + std::queue queue; + std::mutex mutex; + std::condition_variable cvar; + +public: + message_queue() {} + + void push(const T &value) { + std::unique_lock lock(mutex); + queue.push(value); + lock.unlock(); + cvar.notify_all(); + } + + void pop(T* out) { + std::unique_lock lock(mutex); + cvar.wait(lock, [this] { return queue.size() > 0; }); + *out = queue.front(); + queue.pop(); + } +}; + +struct rpc_response { + std::vector output; + bool status; +}; + +using rpc_response_ptr = std::shared_ptr; +using response_queue = message_queue; +using response_queue_ptr = std::shared_ptr; + +struct rpc_request { + rpc_cmd cmd; + std::vector input; + response_queue_ptr response_queue; +}; +using rpc_request_ptr = std::shared_ptr; +using request_queue = message_queue; +using request_queue_ptr = std::shared_ptr; + class rpc_server { public: rpc_server(ggml_backend_t backend) : backend(backend) {} @@ -752,6 +799,7 @@ class rpc_server { bool copy_tensor(const std::vector & input, std::vector & output); bool graph_compute(const std::vector & input, std::vector & output); + void free_all_buffers(); private: ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); ggml_tensor * create_node(uint64_t id, @@ -1046,76 +1094,122 @@ bool rpc_server::graph_compute(const std::vector & input, std::vector input; - std::vector output; - uint64_t input_size; - if (!recv_data(sockfd, &input_size, sizeof(input_size))) { - break; - } - input.resize(input_size); - if (!recv_data(sockfd, input.data(), input_size)) { - break; - } + rpc_request_ptr request; + requestq->pop(&request); + rpc_response_ptr response = std::make_shared(); bool ok = true; - switch (cmd) { + switch (request->cmd) { case ALLOC_BUFFER: { - ok = server.alloc_buffer(input, output); + ok = server.alloc_buffer(request->input, response->output); break; } case GET_ALIGNMENT: { - server.get_alignment(output); + server.get_alignment(response->output); break; } case GET_MAX_SIZE: { - server.get_max_size(output); + server.get_max_size(response->output); break; } case BUFFER_GET_BASE: { - ok = server.buffer_get_base(input, output); + ok = server.buffer_get_base(request->input, response->output); break; } case FREE_BUFFER: { - ok = server.free_buffer(input); + ok = server.free_buffer(request->input); break; } case BUFFER_CLEAR: { - ok = server.buffer_clear(input); + ok = server.buffer_clear(request->input); break; } case SET_TENSOR: { - ok = server.set_tensor(input); + ok = server.set_tensor(request->input); break; } case GET_TENSOR: { - ok = server.get_tensor(input, output); + ok = server.get_tensor(request->input, response->output); break; } case COPY_TENSOR: { - ok = server.copy_tensor(input, output); + ok = server.copy_tensor(request->input, response->output); break; } case GRAPH_COMPUTE: { - ok = server.graph_compute(input, output); + ok = server.graph_compute(request->input, response->output); + break; + } + case GET_DEVICE_MEMORY: { + break; + } + case FREE_ALL_BUFFERS: { + server.free_all_buffers(); + continue; + } + default: { + fprintf(stderr, "Unknown command: %d\n", request->cmd); + ok = false; + } + } + response->status = ok; + request->response_queue->push(response); + } +} + +static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) { + auto responseq = std::make_shared(); + while (true) { + uint8_t cmd; + if (!recv_data(sockfd, &cmd, 1)) { + break; + } + auto request = std::make_shared(); + request->cmd = (rpc_cmd)cmd; + request->response_queue = responseq; + uint64_t input_size; + if (!recv_data(sockfd, &input_size, sizeof(input_size))) { + break; + } + request->input.resize(input_size); + if (!recv_data(sockfd, request->input.data(), input_size)) { + break; + } + bool ok = true; + auto response = std::make_shared(); + switch (cmd) { + case ALLOC_BUFFER: + case GET_ALIGNMENT: + case GET_MAX_SIZE: + case BUFFER_GET_BASE: + case FREE_BUFFER: + case BUFFER_CLEAR: + case SET_TENSOR: + case GET_TENSOR: + case COPY_TENSOR: + case GRAPH_COMPUTE: { + requestq->push(request); + responseq->pop(&response); break; } case GET_DEVICE_MEMORY: { // output serialization format: | free (8 bytes) | total (8 bytes) | - output.resize(2*sizeof(uint64_t), 0); - memcpy(output.data(), &free_mem, sizeof(free_mem)); - memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem)); + response->output.resize(2*sizeof(uint64_t), 0); + memcpy(response->output.data(), &free_mem, sizeof(free_mem)); + memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem)); break; } default: { @@ -1126,17 +1220,29 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre if (!ok) { break; } - uint64_t output_size = output.size(); + uint64_t output_size = response->output.size(); if (!send_data(sockfd, &output_size, sizeof(output_size))) { break; } - if (!send_data(sockfd, output.data(), output_size)) { + if (!send_data(sockfd, response->output.data(), output_size)) { break; } } + auto request = std::make_shared(); + request->cmd = FREE_ALL_BUFFERS; + requestq->push(request); } void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) { +#ifndef _WIN32 + // prevent SIGPIPE when writing to closed socket + signal(SIGPIPE, SIG_IGN); +#endif + auto requestq = std::make_shared(); + std::thread backend_thread = std::thread([=] { + process_requests(backend, requestq); + }); + std::string host; int port; if (!parse_endpoint(endpoint, host, port)) { @@ -1164,7 +1270,7 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free return; } printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); - rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); + rpc_serve_client(requestq, client_socket->fd, free_mem, total_mem); printf("Client connection closed\n"); } #ifdef _WIN32