Skip to content

Commit d92c57d

Browse files
committed
[Runtime] Enable RPCObjectRef return in RPC
This PR enables RPCObjectRef return object similar to the disco transporation. This allows us to do advanced remote debugging when remote vm requires advanced object input like kv cache and shape. To keep the implementation with minRPC(used in some of the limited protocols) forn now, we only support RPCObjectRef for now and do not enable unpacking Shape and String.
1 parent 0b2358c commit d92c57d

File tree

8 files changed

+174
-13
lines changed

8 files changed

+174
-13
lines changed

include/tvm/runtime/object.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ struct TypeIndex {
7272
kRuntimeShapeTuple = 6,
7373
/*! \brief runtime::PackedFunc. */
7474
kRuntimePackedFunc = 7,
75-
/*! \brief runtime::DRef */
75+
/*! \brief runtime::DRef for disco distributed runtime */
7676
kRuntimeDiscoDRef = 8,
77+
/*! \brief runtime::RPCObjectRef */
78+
kRuntimeRPCObjectRef = 9,
7779
// static assignments that may subject to change.
7880
kRuntimeClosure,
7981
kRuntimeADT,

src/runtime/minrpc/minrpc_server.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ class MinRPCExecute : public MinRPCExecInterface {
206206
ret_tcode[1] = kTVMBytes;
207207
ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
208208
TVMByteArrayFree(reinterpret_cast<TVMByteArray*>(ret_value[1].v_handle)); // NOLINT(*)
209-
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
209+
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle ||
210+
rv_tcode == kTVMObjectHandle) {
210211
ret_tcode[1] = kTVMOpaqueHandle;
211212
ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
212213
} else {
@@ -755,7 +756,17 @@ class MinRPCServer {
755756
}
756757

757758
void ReadObject(int* tcode, TVMValue* value) {
758-
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
759+
// handles RPCObject in minRPC
760+
// NOTE: object needs to be supported by C runtime
761+
// because minrpc's restriction of C only
762+
// we only handle RPCObjectRef
763+
uint32_t type_index;
764+
Read(&type_index);
765+
MINRPC_CHECK(type_index == kRuntimeRPCObjectRefTypeIndex);
766+
uint64_t object_handle;
767+
Read(&object_handle);
768+
tcode[0] = kTVMObjectHandle;
769+
value[0].v_handle = reinterpret_cast<void*>(object_handle);
759770
}
760771

761772
private:

src/runtime/minrpc/rpc_reference.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ class Object;
3333
/*! \brief The current RPC procotol version. */
3434
constexpr const char* kRPCProtocolVer = "0.8.0";
3535

36+
/*!
37+
* \brief type index of kRuntimeRPCObjectRefTypeIndex
38+
* \note this needs to be kept consistent with runtime/object.h
39+
* but we explicitly declare it here because minrpc needs to be minimum dep
40+
* only c C API
41+
*/
42+
constexpr const int kRuntimeRPCObjectRefTypeIndex = 9;
43+
3644
// When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered.
3745
const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;
3846

src/runtime/rpc/rpc_endpoint.cc

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
175175
for (int i = 0; i < num_args; ++i) {
176176
int tcode = type_codes[i];
177177
if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) {
178-
LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
179-
<< args[i].AsObjectRef<ObjectRef>()->GetTypeKey() << " is not supported by RPC";
178+
if (!args[i].IsObjectRef<RPCObjectRef>()) {
179+
LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
180+
<< args[i].AsObjectRef<ObjectRef>()->GetTypeKey()
181+
<< " is not supported by RPC";
182+
}
180183
} else if (tcode == kDLDevice) {
181184
DLDevice dev = args[i];
182185
ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC device in the channel";
@@ -219,14 +222,48 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
219222
this->Write(cdata);
220223
}
221224

222-
void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); }
223-
uint64_t GetObjectBytes(void* obj) {
224-
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
225-
return 0;
225+
void WriteObject(Object* obj) {
226+
// NOTE: for now all remote object are encoded as RPCObjectRef
227+
// follow the same disco protocol in case we would like to upgrade later
228+
//
229+
// Rationale note: Only handle remote object allows the same mechanism to work for minRPC
230+
// which is needed for wasm and other env that goes through C API
231+
if (obj->IsInstance<RPCObjectRefObj>()) {
232+
auto* ref = static_cast<RPCObjectRefObj*>(obj);
233+
this->template Write<uint32_t>(kRuntimeRPCObjectRefTypeIndex);
234+
uint64_t handle = reinterpret_cast<uint64_t>(ref->object_handle());
235+
this->template Write<int64_t>(handle);
236+
} else {
237+
LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: "
238+
<< obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
239+
}
240+
}
241+
uint64_t GetObjectBytes(Object* obj) {
242+
if (obj->IsInstance<RPCObjectRefObj>()) {
243+
return sizeof(uint32_t) + sizeof(int64_t);
244+
} else {
245+
LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: "
246+
<< obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
247+
}
226248
}
227249

228250
void ReadObject(int* tcode, TVMValue* value) {
229-
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
251+
// NOTE: for now all remote object are encoded as RPCObjectRef
252+
// follow the same disco protocol in case we would like to upgrade later
253+
//
254+
// Rationale note: Only handle remote object allows the same mechanism to work for minRPC
255+
// which is needed for wasm and other env that goes through C API
256+
uint32_t type_index;
257+
this->template Read<uint32_t>(&type_index);
258+
if (type_index == kRuntimeRPCObjectRefTypeIndex) {
259+
uint64_t handle;
260+
this->template Read<uint64_t>(&handle);
261+
tcode[0] = kTVMObjectHandle;
262+
value[0].v_handle = reinterpret_cast<void*>(handle);
263+
} else {
264+
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
265+
<< Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")";
266+
}
230267
}
231268

232269
void MessageDone() {

src/runtime/rpc/rpc_local_session.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/runtime/registry.h>
2828

2929
#include <memory>
30+
#include <vector>
3031

3132
namespace tvm {
3233
namespace runtime {
@@ -64,7 +65,8 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu
6465
ret_value_pack[2].v_handle = ret_value_pack[1].v_handle;
6566
ret_tcode_pack[2] = kTVMOpaqueHandle;
6667
encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3));
67-
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
68+
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle ||
69+
rv_tcode == kTVMObjectHandle) {
6870
// MoveToCHost means rv no longer manages the object.
6971
// return handle instead.
7072
rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
@@ -88,7 +90,21 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* a
8890
const FEncodeReturn& encode_return) {
8991
PackedFuncObj* pf = static_cast<PackedFuncObj*>(func);
9092
TVMRetValue rv;
91-
pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);
93+
94+
// unwrap RPCObjectRef in case we are directly using it to call LocalSession
95+
std::vector<TVMValue> values(arg_values, arg_values + num_args);
96+
std::vector<int> type_codes(arg_type_codes, arg_type_codes + num_args);
97+
TVMArgs args(arg_values, arg_type_codes, num_args);
98+
99+
for (int i = 0; i < num_args; ++i) {
100+
if (args[i].IsObjectRef<RPCObjectRef>()) {
101+
RPCObjectRef obj_ref = args[i];
102+
values[i].v_handle = obj_ref->object_handle();
103+
continue;
104+
}
105+
}
106+
107+
pf->CallPacked(TVMArgs(values.data(), type_codes.data(), args.size()), &rv);
92108
this->EncodeReturn(std::move(rv), encode_return);
93109
}
94110

src/runtime/rpc/rpc_module.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ class RPCWrappedFunc : public Object {
157157
}
158158
};
159159

160+
TVM_REGISTER_OBJECT_TYPE(RPCObjectRefObj);
161+
160162
// RPC that represents a remote module session.
161163
class RPCModuleNode final : public ModuleNode {
162164
public:
@@ -294,6 +296,11 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) cons
294296
void* handle = args[1];
295297
auto n = make_object<RPCModuleNode>(handle, sess_);
296298
*rv = Module(n);
299+
} else if (tcode == kTVMObjectHandle) {
300+
ICHECK_EQ(args.size(), 2);
301+
void* handle = args[1];
302+
auto n = make_object<RPCObjectRefObj>(handle, sess_);
303+
*rv = ObjectRef(n);
297304
} else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) {
298305
ICHECK_EQ(args.size(), 3);
299306
DLTensor* tensor = args[1];

src/runtime/rpc/rpc_session.h

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class RPCSession {
142142

143143
/*!
144144
* \brief Free a remote function.
145-
* \param handle The remote handle, can be NDArray/PackedFunc/Module
145+
* \param handle The remote handle, can be NDArray/PackedFunc/Module/Object
146146
* \param type_code The type code of the underlying type.
147147
*/
148148
virtual void FreeHandle(void* handle, int type_code) = 0;
@@ -287,6 +287,55 @@ struct RemoteSpace {
287287
std::shared_ptr<RPCSession> sess;
288288
};
289289

290+
/*!
291+
* \brief Object wrapper that represents a reference to a remote object
292+
*/
293+
class RPCObjectRefObj : public Object {
294+
public:
295+
/*!
296+
* \brief constructor
297+
* \param object_handle handle that points to the remote object
298+
* \param sess The remote session
299+
*/
300+
RPCObjectRefObj(void* object_handle, std::shared_ptr<RPCSession> sess)
301+
: object_handle_(object_handle), sess_(sess) {}
302+
303+
~RPCObjectRefObj() {
304+
if (object_handle_ != nullptr) {
305+
try {
306+
sess_->FreeHandle(object_handle_, kTVMObjectHandle);
307+
} catch (const Error& e) {
308+
// fault tolerance to remote close
309+
}
310+
object_handle_ = nullptr;
311+
}
312+
}
313+
314+
const std::shared_ptr<RPCSession>& sess() const { return sess_; }
315+
316+
void* object_handle() const { return object_handle_; }
317+
318+
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeRPCObjectRef;
319+
static constexpr const char* _type_key = "runtime.RPCObjectRef";
320+
TVM_DECLARE_FINAL_OBJECT_INFO(RPCObjectRefObj, Object);
321+
322+
private:
323+
// The object handle
324+
void* object_handle_{nullptr};
325+
// The local channel
326+
std::shared_ptr<RPCSession> sess_;
327+
};
328+
329+
/*!
330+
* \brief Managed reference to RPCObjectRefObj.
331+
* \sa RPCObjectRefObj
332+
* \note No public constructor is provided as it is not supposed to be directly created by users.
333+
*/
334+
class RPCObjectRef : public ObjectRef {
335+
public:
336+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, RPCObjectRefObj);
337+
};
338+
290339
/*!
291340
* \brief Create a Global RPC module that refers to the session.
292341
* \param sess The RPC session of the global module.

tests/python/runtime/test_runtime_rpc.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ def test_rpc_return_ndarray():
426426
ref_count = m("ref_count")
427427
get_elem = m("get_elem")
428428
get_arr_elem = m("get_arr_elem")
429+
429430
# array test
430431
def run_arr_test():
431432
arr = get_arr()
@@ -435,6 +436,36 @@ def run_arr_test():
435436
run_arr_test()
436437

437438

439+
@tvm.testing.requires_rpc
440+
def test_rpc_return_remote_object():
441+
def check(client, is_local):
442+
make_shape = client.get_function("runtime.ShapeTuple")
443+
get_elem = client.get_function("runtime.GetShapeTupleElem")
444+
get_size = client.get_function("runtime.GetShapeTupleSize")
445+
shape = make_shape(2, 3)
446+
assert shape.type_key == "runtime.RPCObjectRef"
447+
assert get_elem(shape, 0) == 2
448+
assert get_elem(shape, 1) == 3
449+
assert get_size(shape) == 2
450+
451+
# start server
452+
server = rpc.Server(key="x1")
453+
client = rpc.connect("127.0.0.1", server.port, key="x1")
454+
check(rpc.LocalSession(), True)
455+
check(client, False)
456+
457+
def check_minrpc():
458+
if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None:
459+
return
460+
# Test minrpc server.
461+
temp = utils.tempdir()
462+
minrpc_exec = temp.relpath("minrpc")
463+
tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, [])
464+
check(rpc.PopenSession(minrpc_exec), False)
465+
466+
check_minrpc()
467+
468+
438469
@tvm.testing.requires_rpc
439470
def test_local_func():
440471
client = rpc.LocalSession()

0 commit comments

Comments
 (0)