@@ -65,25 +65,27 @@ struct CUDAGraphCaptureKeyEqual {
6565 }
6666};
6767
68- /* ! \brief The cache states of a CUDA graph. */
69- class CUDAGraphCache : public Object {
70- public:
71- struct CaptureResult {
72- ~CaptureResult () {
73- if (exec) {
74- CUDA_CALL (cudaGraphExecDestroy (exec));
75- }
68+ /* ! \brief The captured state of a CUDA graph */
69+ struct CUDAGraphCapturedState {
70+ ~CUDAGraphCapturedState () {
71+ if (exec) {
72+ CUDA_CALL (cudaGraphExecDestroy (exec));
7673 }
77- /* !
78- * \brief Tuple of intemediate tensors in the capture func that will be used outside the
79- * capture func
80- */
81- ObjectRef states;
82- /* ! \brief The instantiated cuda graph */
83- cudaGraphExec_t exec = nullptr ;
84- };
74+ }
8575
86- static CUDAGraphCache* Get () { return dmlc::ThreadLocalStore<CUDAGraphCache>::Get (); }
76+ /* !
77+ * \brief Tuple of intemediate tensors in the capture func that will be used outside the
78+ * capture func
79+ */
80+ ObjectRef states;
81+ /* ! \brief The instantiated cuda graph */
82+ cudaGraphExec_t exec = nullptr ;
83+ };
84+
85+ /* ! \brief The VM extension of CUDA graph. */
86+ class CUDAGraphExtensionNode : public VMExtensionNode {
87+ public:
88+ TVM_DECLARE_FINAL_OBJECT_INFO (CUDAGraphExtensionNode, VMExtensionNode);
8789
8890 /* !
8991 * \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
@@ -107,7 +109,7 @@ class CUDAGraphCache : public Object {
107109
108110 cudaStream_t capture_stream;
109111 CUDA_CALL (cudaStreamCreate (&capture_stream));
110- CUDAGraphCache::CaptureResult entry;
112+ CUDAGraphCapturedState entry;
111113
112114 // Set up arguments for the graph execution
113115 Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
@@ -164,12 +166,14 @@ class CUDAGraphCache : public Object {
164166 return alloc_result;
165167 }
166168
169+ static constexpr const char * _type_key = " relax_vm.CUDAGraphExtension" ;
170+
167171 private:
168172 /* !
169173 * \brief The cache of captured cuda graphs. The key is a unique index for the capture function.
170174 * The value is the result of the capture.
171175 */
172- std::unordered_map<CUDAGraphCaptureKey, CaptureResult , CUDAGraphCaptureKeyHash,
176+ std::unordered_map<CUDAGraphCaptureKey, CUDAGraphCapturedState , CUDAGraphCaptureKeyHash,
173177 CUDAGraphCaptureKeyEqual>
174178 capture_cache_;
175179 /* !
@@ -179,29 +183,39 @@ class CUDAGraphCache : public Object {
179183 std::unordered_map<int64_t , ObjectRef> alloc_cache_;
180184};
181185
186+ /* ! Managed reference to CUDAGraphExtensionNode */
187+ class CUDAGraphExtension : public VMExtension {
188+ public:
189+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (CUDAGraphExtension, VMExtension, CUDAGraphExtensionNode);
190+ static CUDAGraphExtension Create () {
191+ auto data_ = make_object<CUDAGraphExtensionNode>();
192+ return CUDAGraphExtension (std::move (data_));
193+ }
194+ };
195+
182196TVM_REGISTER_GLOBAL (" vm.builtin.cuda_graph.run_or_capture" )
183197 .set_body([](TVMArgs args, TVMRetValue* rv) {
184198 ICHECK (args.size () == 5 || args.size () == 4 );
185199 VirtualMachine* vm = VirtualMachine::GetContextPtr (args[0 ]);
200+ auto extension = vm->GetOrCreateExtension <CUDAGraphExtension>();
186201 ObjectRef capture_func = args[1 ];
187202 ObjectRef func_args = args[2 ];
188203 int64_t entry_index = args[3 ];
189204 Optional<ShapeTuple> shape_expr = NullOpt;
190205 if (args.size () == 5 ) {
191206 shape_expr = args[4 ].AsObjectRef <ShapeTuple>();
192207 }
193- CUDAGraphCache* cache = CUDAGraphCache::Get ();
194- *rv = cache->RunOrCapture (vm, capture_func, func_args, entry_index, shape_expr);
208+ *rv = extension->RunOrCapture (vm, capture_func, func_args, entry_index, shape_expr);
195209 });
196210
197211TVM_REGISTER_GLOBAL (" vm.builtin.cuda_graph.get_cached_alloc" )
198212 .set_body([](TVMArgs args, TVMRetValue* rv) {
199213 ICHECK_EQ (args.size (), 3 );
200214 VirtualMachine* vm = VirtualMachine::GetContextPtr (args[0 ]);
215+ auto extension = vm->GetOrCreateExtension <CUDAGraphExtension>();
201216 ObjectRef alloc_func = args[1 ];
202217 int64_t entry_index = args[2 ];
203- CUDAGraphCache* cache = CUDAGraphCache::Get ();
204- *rv = cache->GetCachedAllocation (vm, alloc_func, entry_index);
218+ *rv = extension->GetCachedAllocation (vm, alloc_func, entry_index);
205219 });
206220
207221} // namespace relax_vm
0 commit comments