@@ -96,27 +96,37 @@ static ggml_guid_t ggml_backend_rpc_guid() {
9696 return &guid;
9797}
9898
99- struct ggml_backend_rpc_buffer_type_context {
99+ struct rpc_backend {
100+ int ref_count;
101+ std::string endpoint;
100102 std::shared_ptr<socket_t > sock;
103+ ggml_backend_t backend;
104+ };
105+
106+ using rpc_backend_ptr = std::shared_ptr<rpc_backend>;
107+
108+ struct ggml_backend_rpc_buffer_type_context {
109+ std::shared_ptr<rpc_backend> back;
101110 std::string name;
102111 size_t alignment;
103112 size_t max_size;
104113};
105114
106115struct ggml_backend_rpc_context {
107- std::string endpoint;
108116 std::string name;
109- std::shared_ptr<socket_t > sock ;
117+ std::shared_ptr<rpc_backend> back ;
110118 ggml_backend_buffer_type_t buft;
111119};
112120
113121struct ggml_backend_rpc_buffer_context {
114- std::shared_ptr<socket_t > sock ;
122+ std::shared_ptr<rpc_backend> back ;
115123 std::unordered_map<ggml_backend_buffer_t , void *> base_cache;
116124 uint64_t remote_ptr;
117125 std::string name;
118126};
119127
128+ static std::unordered_map<std::string, rpc_backend_ptr> instances;
129+
120130// RPC helper functions
121131
122132static std::shared_ptr<socket_t > make_socket (sockfd_t fd) {
@@ -231,14 +241,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
231241 return true ;
232242}
233243
// Splits an endpoint of the form "host:port" at the first ':'.
//
// On success stores the part before the ':' in `host`, the numeric part after
// it in `port`, and returns true. Returns false (leaving `host` and `port`
// untouched) when there is no ':' or when the port is empty, contains anything
// other than decimal digits, or is larger than 65535.
//
// The port is parsed by hand rather than with std::stoi so that malformed
// input is reported through the return value instead of an exception.
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    const size_t sep = endpoint.find(':');
    if (sep == std::string::npos) {
        return false;
    }
    if (sep + 1 >= endpoint.size()) {
        return false; // nothing after the ':'
    }

    int value = 0;
    for (size_t i = sep + 1; i < endpoint.size(); ++i) {
        const char c = endpoint[i];
        if (c < '0' || c > '9') {
            return false;
        }
        value = value * 10 + (c - '0');
        if (value > 65535) {
            return false; // checked per digit, so `value` can never overflow
        }
    }

    host = endpoint.substr(0, sep);
    port = value;
    return true;
}
244253
@@ -273,6 +282,22 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
273282
274283// RPC client-side implementation
275284
285+ static void free_rpc_backend (rpc_backend_ptr rpc_back) {
286+ ggml_backend_t backend = rpc_back->backend ;
287+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
288+ std::string endpoint = rpc_back->endpoint ;
289+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft ->context ;
290+ GGML_PRINT_DEBUG (" [%s] closing connection to %s\n " , __func__, endpoint.c_str ());
291+ delete buft_ctx;
292+ delete rpc_ctx->buft ;
293+ delete rpc_ctx;
294+ delete backend;
295+ instances.erase (endpoint);
296+ #ifdef _WIN32
297+ WSACleanup ();
298+ #endif
299+ }
300+
276301GGML_CALL static const char * ggml_backend_rpc_buffer_get_name (ggml_backend_buffer_t buffer) {
277302 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
278303 return ctx->name .c_str ();
@@ -285,9 +310,13 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
285310 uint64_t remote_ptr = ctx->remote_ptr ;
286311 memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
287312 std::vector<uint8_t > output;
288- bool status = send_rpc_cmd (ctx->sock , FREE_BUFFER, input, output);
313+ bool status = send_rpc_cmd (ctx->back -> sock , FREE_BUFFER, input, output);
289314 GGML_ASSERT (status);
290315 GGML_ASSERT (output.empty ());
316+ ctx->back ->ref_count --;
317+ if (ctx->back ->ref_count == 0 ) {
318+ free_rpc_backend (ctx->back );
319+ }
291320 delete ctx;
292321}
293322
@@ -301,7 +330,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
301330 uint64_t remote_ptr = ctx->remote_ptr ;
302331 memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
303332 std::vector<uint8_t > output;
304- bool status = send_rpc_cmd (ctx->sock , BUFFER_GET_BASE, input, output);
333+ bool status = send_rpc_cmd (ctx->back -> sock , BUFFER_GET_BASE, input, output);
305334 GGML_ASSERT (status);
306335 GGML_ASSERT (output.size () == sizeof (uint64_t ));
307336 // output serialization format: | base_ptr (8 bytes) |
@@ -360,7 +389,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
360389 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
361390 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
362391 std::vector<uint8_t > output;
363- bool status = send_rpc_cmd (ctx->sock , SET_TENSOR, input, output);
392+ bool status = send_rpc_cmd (ctx->back -> sock , SET_TENSOR, input, output);
364393 GGML_ASSERT (status);
365394}
366395
@@ -374,7 +403,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
374403 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
375404 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &size, sizeof (size));
376405 std::vector<uint8_t > output;
377- bool status = send_rpc_cmd (ctx->sock , GET_TENSOR, input, output);
406+ bool status = send_rpc_cmd (ctx->back -> sock , GET_TENSOR, input, output);
378407 GGML_ASSERT (status);
379408 GGML_ASSERT (output.size () == size);
380409 // output serialization format: | data (size bytes) |
@@ -387,7 +416,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
387416 ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context ;
388417 ggml_backend_buffer_t dst_buffer = dst->buffer ;
389418 ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context ;
390- if (src_ctx->sock != dst_ctx->sock ) {
419+ if (src_ctx->back != dst_ctx->back ) {
391420 return false ;
392421 }
393422 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
@@ -399,7 +428,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
399428 memcpy (input.data (), &rpc_src, sizeof (rpc_src));
400429 memcpy (input.data () + sizeof (rpc_src), &rpc_dst, sizeof (rpc_dst));
401430 std::vector<uint8_t > output;
402- bool status = send_rpc_cmd (ctx->sock , COPY_TENSOR, input, output);
431+ bool status = send_rpc_cmd (ctx->back -> sock , COPY_TENSOR, input, output);
403432 GGML_ASSERT (status);
404433 // output serialization format: | result (1 byte) |
405434 GGML_ASSERT (output.size () == 1 );
@@ -414,7 +443,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
414443 memcpy (input.data (), &ctx->remote_ptr , sizeof (ctx->remote_ptr ));
415444 memcpy (input.data () + sizeof (ctx->remote_ptr ), &value, sizeof (value));
416445 std::vector<uint8_t > output;
417- bool status = send_rpc_cmd (ctx->sock , BUFFER_CLEAR, input, output);
446+ bool status = send_rpc_cmd (ctx->back -> sock , BUFFER_CLEAR, input, output);
418447 GGML_ASSERT (status);
419448}
420449
@@ -442,7 +471,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
442471 std::vector<uint8_t > input (input_size, 0 );
443472 memcpy (input.data (), &size, sizeof (size));
444473 std::vector<uint8_t > output;
445- bool status = send_rpc_cmd (buft_ctx->sock , ALLOC_BUFFER, input, output);
474+ bool status = send_rpc_cmd (buft_ctx->back -> sock , ALLOC_BUFFER, input, output);
446475 GGML_ASSERT (status);
447476 GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
448477 // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,8 +482,9 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
453482 if (remote_ptr != 0 ) {
454483 ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
455484 ggml_backend_rpc_buffer_interface,
456- new ggml_backend_rpc_buffer_context{buft_ctx->sock , {}, remote_ptr, " RPC" },
485+ new ggml_backend_rpc_buffer_context{buft_ctx->back , {}, remote_ptr, " RPC" },
457486 remote_size);
487+ buft_ctx->back ->ref_count ++;
458488 return buffer;
459489 } else {
460490 return nullptr ;
@@ -508,7 +538,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
508538 }
509539 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
510540 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
511- return buft_ctx->sock == rpc_ctx->sock ;
541+ return buft_ctx->back == rpc_ctx->back ;
512542}
513543
514544static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +551,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
521551 /* .is_host = */ NULL ,
522552};
523553
524-
525554GGML_CALL static const char * ggml_backend_rpc_name (ggml_backend_t backend) {
526555 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
527556
@@ -530,11 +559,10 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
530559
531560GGML_CALL static void ggml_backend_rpc_free (ggml_backend_t backend) {
532561 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
533- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft ->context ;
534- delete buft_ctx;
535- delete rpc_ctx->buft ;
536- delete rpc_ctx;
537- delete backend;
562+ rpc_ctx->back ->ref_count --;
563+ if (rpc_ctx->back ->ref_count == 0 ) {
564+ free_rpc_backend (rpc_ctx->back );
565+ }
538566}
539567
540568GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type (ggml_backend_t backend) {
@@ -590,7 +618,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
590618 std::vector<uint8_t > input;
591619 serialize_graph (cgraph, input);
592620 std::vector<uint8_t > output;
593- bool status = send_rpc_cmd (rpc_ctx->sock , GRAPH_COMPUTE, input, output);
621+ bool status = send_rpc_cmd (rpc_ctx->back -> sock , GRAPH_COMPUTE, input, output);
594622 GGML_ASSERT (status);
595623 GGML_ASSERT (output.size () == 1 );
596624 return (enum ggml_status)output[0 ];
@@ -624,17 +652,9 @@ static ggml_backend_i ggml_backend_rpc_interface = {
624652 /* .event_synchronize = */ NULL ,
625653};
626654
627- static std::unordered_map<std::string, ggml_backend_t > instances;
628-
629- GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type (const char * endpoint) {
630- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
631- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type (backend) : nullptr ;
632- }
633-
634- GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
635- std::string endpoint_str (endpoint);
636- if (instances.find (endpoint_str) != instances.end ()) {
637- return instances[endpoint_str];
655+ static rpc_backend_ptr create_rpc_backend (const std::string & endpoint) {
656+ if (instances.find (endpoint) != instances.end ()) {
657+ return instances[endpoint];
638658 }
639659#ifdef _WIN32
640660 {
@@ -645,7 +665,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
645665 }
646666 }
647667#endif
648- fprintf (stderr, " Connecting to %s\n " , endpoint);
668+ fprintf (stderr, " Connecting to %s\n " , endpoint. c_str () );
649669 std::string host;
650670 int port;
651671 if (!parse_endpoint (endpoint, host, port)) {
@@ -657,11 +677,12 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
657677 }
658678 size_t alignment = get_alignment (sock);
659679 size_t max_size = get_max_size (sock);
680+ auto rpc_back = std::make_shared<rpc_backend>();
660681 ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
661- /* .sock = */ sock ,
662- /* .name = */ " RPC" + std::to_string (sock->fd ),
682+ /* .back = */ rpc_back ,
683+ /* .name = */ " RPC" + std::to_string (sock->fd ),
663684 /* .alignment = */ alignment,
664- /* .max_size = */ max_size
685+ /* .max_size = */ max_size
665686 };
666687
667688 ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
@@ -670,19 +691,37 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
670691 };
671692
672693 ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
673- /* .endpoint = */ endpoint,
674694 /* .name = */ " RPC" + std::to_string (sock->fd ),
675- /* .sock = */ sock ,
695+ /* .back = */ rpc_back ,
676696 /* .buft = */ buft
677697 };
678698
679- instances[endpoint] = new ggml_backend {
699+ ggml_backend_t backend = new ggml_backend {
680700 /* .guid = */ ggml_backend_rpc_guid (),
681701 /* .interface = */ ggml_backend_rpc_interface,
682702 /* .context = */ ctx
683703 };
704+ rpc_back->sock = sock;
705+ rpc_back->endpoint = endpoint;
706+ rpc_back->backend = backend;
707+ rpc_back->ref_count = 0 ;
708+ instances[endpoint] = rpc_back;
709+ return rpc_back;
710+ }
684711
685- return instances[endpoint];
712+ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type (const char * endpoint) {
713+ auto rpc_back = create_rpc_backend (endpoint);
714+ return rpc_back != nullptr ? ggml_backend_rpc_get_default_buffer_type (rpc_back->backend ) : nullptr ;
715+ }
716+
717+ GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
718+ std::string endpoint_str (endpoint);
719+ auto rpc_back = create_rpc_backend (endpoint_str);
720+ if (rpc_back == nullptr ) {
721+ return nullptr ;
722+ }
723+ rpc_back->ref_count ++;
724+ return rpc_back->backend ;
686725}
687726
688727GGML_API GGML_CALL bool ggml_backend_is_rpc (ggml_backend_t backend) {
@@ -706,14 +745,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
706745}
707746
708747GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory (const char * endpoint, size_t * free, size_t * total) {
709- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
710- if (backend == nullptr ) {
748+ auto rpc_back = create_rpc_backend (endpoint);
749+ if (rpc_back == nullptr ) {
711750 *free = 0 ;
712751 *total = 0 ;
713752 return ;
714753 }
715- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context ;
716- get_device_memory (ctx->sock , free, total);
754+ get_device_memory (rpc_back->sock , free, total);
717755}
718756
719757// RPC server-side implementation
0 commit comments