Skip to content

Commit bd361b9

Browse files
[RELAY] [AST] Add virtual_device as a first class field in Relay (#9641)
1 parent 3b28216 commit bd361b9

File tree

7 files changed

+138
-44
lines changed

7 files changed

+138
-44
lines changed

include/tvm/ir/expr.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ namespace tvm {
3939

4040
using tvm::runtime::String;
4141

42+
// Forward-declare SEScope to avoid circular imports.
43+
class SEScope;
44+
4245
/*!
4346
* \brief Base type of all the expressions.
4447
* \sa Expr
@@ -165,6 +168,29 @@ class RelayExprNode : public BaseExprNode {
165168
template <typename TTypeNode>
166169
inline const TTypeNode* type_as() const;
167170

171+
/*!
172+
* \brief The virtual device (SEScope) for this node (the result of device planning).
173+
* For first-order expressions (non functions), this describes where the result of evaluating the
174+
* expression should be stored. Note that currently, all composite first-order values (tuples,
175+
* references, ADTs) must be stored on the same virtual device. This means that it is not possible
176+
* to store two tuple fields on different devices, so we only need one virtual device for these
177+
* types.
178+
*
179+
* For expressions that have the function type, the virtual device describes where the result of
180+
* the call to the function or closure is stored (instead of where the function itself is stored).
181+
* The SEScope's Target field describes how the body of the function should be compiled.
182+
*
183+
* \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
184+
* import.
185+
*/
186+
mutable ObjectRef virtual_device_;
187+
188+
/*!
189+
* \return The virtual device (SEScope).
190+
* If the virtual device is not defined, returns SEScope::FullyUnconstrained().
191+
*/
192+
SEScope virtual_device() const;
193+
168194
static constexpr const char* _type_key = "RelayExpr";
169195
static constexpr const uint32_t _type_child_slots = 22;
170196
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);

include/tvm/relay/expr.h

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/ir/expr.h>
2929
#include <tvm/ir/module.h>
3030
#include <tvm/ir/op.h>
31+
#include <tvm/target/se_scope.h>
3132

3233
#include <functional>
3334
#include <stack>
@@ -151,10 +152,14 @@ class Tuple : public Expr {
151152
* \param tuple The tuple to copy
152153
* \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
153154
* tuple->fields.
154-
* \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
155+
* \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none,
156+
* ret_tuple->virtual_device = tuple->virtual_device.
157+
* \param opt_span The (optional) span for the copied tuple. If none,
158+
* ret_tuple->span = tuple->span.
155159
*/
156160
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
157-
Optional<Span> opt_span = Optional<Span>(nullptr));
161+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
162+
Optional<Span> opt_span = Optional<Span>());
158163

159164
/*!
160165
* \brief Local variables used in the let expression.
@@ -240,14 +245,17 @@ class Var : public Expr {
240245
* \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid.
241246
* \param opt_type_annotation The (optional) type_annotation for the copied var. If none,
242247
* ret_var->type_annotation = var->type_annotation.
248+
* \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none,
249+
* ret_tuple->virtual_device = tuple->virtual_device.
243250
* \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span.
244-
* \return If all properties are null or the same as the property in the input var
245-
* (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise,
246-
* we return a copy of call with the different fields overwritten. (i.e., if
247-
* opt_vid.value() != var->vid, then ret_var->vid = opt_.value()).
251+
* \return If all properties are null or the same as the property in the input var (i.e., opt_vid is
252+
* null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, we return a copy of
253+
* call with the different fields overwritten. (i.e., if opt_vid.value() != var->vid, then
254+
* ret_var->vid = opt_.value()).
248255
*/
249256
Var WithFields(Var var, Optional<Id> opt_vid = Optional<Id>(),
250257
Optional<Type> opt_type_annotation = Optional<Type>(),
258+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
251259
Optional<Span> opt_span = Optional<Span>());
252260

253261
/*!
@@ -362,16 +370,19 @@ class Call : public Expr {
362370
* call->attrs.
363371
* \param opt_type_args The (optional) type args for the copied call. If none,
364372
* ret_call->type_args = call->type_args.
373+
* \param opt_virtual_device The (optional) virtual_device for the copied call. If none,
374+
* ret_call->virtual_device = call->virtual_device.
365375
* \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span.
366-
* \return If all properties are null or the same as the property in the input call
367-
* (i.e., opt_op is null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we
368-
* return a copy of call with the different fields overwritten. (i.e., if opt_op.value() !=
369-
* call->op, then ret_call->op = opt_op.value()).
376+
* \return If all properties are null or the same as the property in the input call (i.e., opt_op is
377+
* null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we return a copy of
378+
* call with the different fields overwritten. (i.e., if opt_op.value() != call->op, then
379+
* ret_call->op = opt_op.value()).
370380
*/
371381
Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
372382
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
373383
Optional<Attrs> opt_attrs = Optional<Attrs>(),
374384
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(),
385+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
375386
Optional<Span> opt_span = Optional<Span>());
376387

377388
/*!
@@ -456,6 +467,8 @@ class Let : public Expr {
456467
* \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op.
457468
* \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args.
458469
* \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs.
470+
* \param opt_virtual_device The (optional) virtual_device for the copied let. If none,
471+
* ret_let->virtual_device = let->virtual_device.
459472
* \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span.
460473
* \return If all properties are null or the same as the property in the input let (i.e., opt_var is
461474
* null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of
@@ -465,6 +478,7 @@ class Let : public Expr {
465478
Let WithFields(Let let, Optional<Var> opt_var = Optional<Var>(),
466479
Optional<Expr> opt_value = Optional<Expr>(),
467480
Optional<Expr> opt_body = Optional<Expr>(),
481+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
468482
Optional<Span> opt_span = Optional<Span>());
469483

470484
/*!
@@ -539,17 +553,19 @@ class If : public Expr {
539553
* ret_if->true_branch = ret_if->false_branch.
540554
* \param opt_false_branch The (optional) false_branch
541555
* for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch.
542-
* \param opt_span
543-
* The (optional) span for the copied if_expr. If none, ret_if->span = if_expr->span.
544-
* \return If all
545-
* properties are null or the same as the property in the input if_expr (i.e., opt_cond is null or
546-
* opt_cond.value() == if_expr->cond, etc.), then we return if_expr. Otherwise, we return a copy of
547-
* if_expr with the different fields overwritten. (i.e., if opt_cond.value() != if_expr->cond, then
548-
* ret_if->cond = opt_cond.value()).
556+
* \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none,
557+
* ret_if->virtual_device = if_expr->virtual_device.
558+
* \param opt_span The (optional) span for the copied if_expr. If none,
559+
* ret_if->span = if_expr->span.
560+
* \return If all properties are null or the same as the property in
561+
* the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we
562+
* return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten.
563+
* (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()).
549564
*/
550565
If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
551566
Optional<Expr> opt_true_branch = Optional<Expr>(),
552567
Optional<Expr> opt_false_branch = Optional<Expr>(),
568+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
553569
Optional<Span> opt_span = Optional<Span>());
554570

555571
/*! \brief Get index-th field out of a tuple. */
@@ -603,8 +619,9 @@ class TupleGetItem : public Expr {
603619
* ret_tuple_get_item->tuple = tuple_get_item->tuple.
604620
* \param opt_index The (optional) index for the copied tuple_get_item. If none,
605621
* ret_tuple_get_item->index = tuple_get_item->index.
606-
* \param
607-
* opt_span The (optional) span for the copied tuple_get_item. If none,
622+
* \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item.
623+
* If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device.
624+
* \param opt_span The (optional) span for the copied tuple_get_item. If none,
608625
* ret_tuple_get_item->span = tuple_get_item->span.
609626
* \return If all properties are null or the same as the property in the input tuple_get_item
610627
* (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return
@@ -614,6 +631,7 @@ class TupleGetItem : public Expr {
614631
*/
615632
TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(),
616633
Optional<Integer> opt_index = Optional<Integer>(),
634+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
617635
Optional<Span> opt_span = Optional<Span>());
618636

619637
/*! \brief Create a new Reference out of initial value. */
@@ -663,6 +681,8 @@ class RefCreate : public Expr {
663681
* \param ref_create The ref_create to copy.
664682
* \param opt_value The (optional) value for the copied ref_create. If none,
665683
* ret_ref_create->value = ref_create->value.
684+
* \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none,
685+
* ret_ref_create->virtual_device = ref_create->virtual_device.
666686
* \param opt_span The (optional) span for the copied ref_create. If none,
667687
* ret_ref_create->span = ref_create->span.
668688
* \return If all properties are null or the same as the property in the input ref_create
@@ -672,6 +692,7 @@ class RefCreate : public Expr {
672692
* ret_ref_create->value = opt_value.value()).
673693
*/
674694
RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value = Optional<Expr>(),
695+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
675696
Optional<Span> opt_span = Optional<Span>());
676697

677698
/*! \brief Get value out of Reference. */
@@ -720,15 +741,18 @@ class RefRead : public Expr {
720741
* \param ref_read The ref_read to copy.
721742
* \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref =
722743
* ref_read->ref.
723-
* \param opt_span
724-
* The (optional) span for the copied ref_read. If none, ret_ref_read->span = ref_read->span.
725-
* \return If all properties are null or the same as the property in the input ref_read
726-
* (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return ref_read.
727-
* Otherwise, we return a copy of ref_read with the different fields overwritten.
728-
* (i.e., if opt_ref.value() != ref_read->ref, then
729-
* ret_ref_read->ref = opt_ref.value()).
744+
* \param opt_virtual_device
745+
* The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device =
746+
* ref_read->virtual_device.
747+
* \param opt_span The (optional) span for the copied ref_read. If none, ret_ref_read->span =
748+
* ref_read->span.
749+
* \return If all properties are null or the same as the property in the input
750+
* ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return
751+
* ref_read. Otherwise, we return a copy of ref_read with the different fields overwritten. (i.e.,
752+
* if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()).
730753
*/
731754
RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref = Optional<Expr>(),
755+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
732756
Optional<Span> opt_span = Optional<Span>());
733757

734758
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
@@ -784,16 +808,19 @@ class RefWrite : public Expr {
784808
* ret_ref_write->ref = ref_write->ref.
785809
* \param opt_value The (optional) value for the copied ref_write. If none,
786810
* ret_ref_write->value = ref_write->value.
787-
* \param opt_span
788-
* The (optional) span for the copied ref_write. If none, ret_ref_write->span = ref_write->span.
789-
* \return If all properties are null or the same as the property in the input ref_write
790-
* (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write.
791-
* Otherwise, we return a copy of ref_write with the different fields overwritten.
792-
* (i.e., if ref_write.value() != ref_write->ref, then
793-
* ret_ref_write->ref = opt_ref.value()).
811+
* \param opt_virtual_device
812+
* The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device =
813+
* ref_write->virtual_device.
814+
* \param opt_span The (optional) span for the copied ref_write. If none, ret_ref_write->span =
815+
* ref_write->span.
816+
* \return If all properties are null or the same as the property in the input ref_write (i.e.,
817+
* opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. Otherwise,
818+
* we return a copy of ref_write with the different fields overwritten. (i.e., if ref_write.value()
819+
* != ref_write->ref, then ret_ref_write->ref = opt_ref.value()).
794820
*/
795821
RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref = Optional<Expr>(),
796822
Optional<Expr> opt_value = Optional<Expr>(),
823+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
797824
Optional<Span> opt_span = Optional<Span>());
798825

799826
/*!

include/tvm/relay/function.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class Function : public BaseFunc {
134134
* \param opt_attrs
135135
* The (optional) attributes for the copied function. If none,
136136
* ret_function->attrs = function->attrs.
137+
* \param opt_virtual_device The (optional) virtual_device for the copied function. If none,
138+
* ret_function->virtual_device = function->virtual_device.
137139
* \param opt_span The (optional) span for the copied function. If none,
138140
* ret_function->span = function->span.
139141
* \return If all properties are null or the same as the property in the input function
@@ -146,6 +148,7 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params = Optiona
146148
Optional<Type> opt_ret_type = Optional<Type>(),
147149
Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
148150
Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
151+
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
149152
Optional<Span> opt_span = Optional<Span>());
150153

151154
/*

rust/tvm-sys/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@ enumn = "^0.1"
8585
[build-dependencies]
8686
bindgen = { version="0.57", default-features = false, features = ["runtime"] }
8787
anyhow = "^1.0"
88-
tvm-build = "0.2.1"
88+
tvm-build = "0.2.4"

rust/tvm/src/ir/relay/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@ pub mod attrs;
4040
pub struct ExprNode {
4141
pub base: BaseExprNode,
4242
pub checked_type: Type,
43+
pub virtual_device: ObjectRef,
4344
}
4445

4546
impl ExprNode {
4647
pub fn base<T: IsObject>(span: Span) -> ExprNode {
4748
ExprNode {
4849
base: BaseExprNode::base::<T>(span.clone()),
4950
checked_type: Type::null(),
51+
virtual_device: ObjectRef::null(),
5052
}
5153
}
5254
}

0 commit comments

Comments
 (0)