@@ -66,6 +66,7 @@ enum class Opcode {
6666 AllocStorage = 16U ,
6767 ShapeOf = 17U ,
6868 ReshapeTensor = 18U ,
69+ DeviceCopy = 19U ,
6970};
7071
7172/* ! \brief A single virtual machine instruction.
@@ -196,6 +197,8 @@ struct Instruction {
196197 Index alignment;
197198 /* ! \brief The hint of the dtype. */
198199 DLDataType dtype_hint;
200+ /* ! \brief The device type of the allocation. */
201+ Index device_type;
199202 } alloc_storage;
200203 struct /* ShapeOf Operands */ {
201204 RegName tensor;
@@ -204,6 +207,13 @@ struct Instruction {
204207 RegName tensor;
205208 RegName newshape;
206209 } reshape_tensor;
210+ struct /* DeviceCopy Operands */ {
211+ RegName src;
212+ /* ! \brief The source device type. */
213+ Index src_device_type;
214+ /* ! \brief The destination device type. */
215+ Index dst_device_type;
216+ };
207217 };
208218
209219 /* !
@@ -341,11 +351,12 @@ struct Instruction {
341351 * \param size The size of the allocation.
342352 * \param alignment The allocation's alignment.
343353 * \param dtype_hint The data type hint for the allocator.
354+ * \param device_type The device type for the allocator.
344355 * \param dst The destination to place the storage.
345356 * \return The alloc storage instruction.
346357 */
347358 static Instruction AllocStorage (RegName size, Index alignment, DLDataType dtype_hint,
348- RegName dst);
359+ Index device_type, RegName dst);
349360 /* !
350361 * \brief Get the shape of an input tensor.
351362 * \param tensor The input tensor.
@@ -361,6 +372,16 @@ struct Instruction {
361372 * \return The reshape tensor instruction.
362373 */
363374 static Instruction ReshapeTensor (RegName tensor, RegName newshape, RegName dst);
375+ /* !
376+ * \brief Copy tensor cross different devices.
377+ * \param src The source register.
378+ * \param src_device_type The device type of the tensor for the source register.
379+ * \param dst_device_type The device type of the tensor ofr the destination register.
380+ * \param dst The destination register to store the copied tensor.
381+ * \return The device copy instruction.
382+ */
383+ static Instruction DeviceCopy (RegName src, Index src_device_type, Index dst_device_type,
384+ RegName dst);
364385
365386 Instruction ();
366387 Instruction (const Instruction& instr);
0 commit comments