Commit 579125e

[Relax] Handle binary operations between Tensor and PrimValue
Prior to this commit, binary operations were defined only between two tensors. This commit allows binary operations between a tensor and a `relax::PrimValue`. When inferring the output `StructInfo`, a binary operation with a `PrimValue` operand produces the same output as the equivalent 0-d tensor operand. When legalizing an operation that contains a `PrimValue`, the `PrimValue` is lowered to a primitive TIR argument.
1 parent 5daa303 commit 579125e
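
As an illustration of the new behavior, here is a minimal sketch (not taken from the commit; the module name, shapes, and dtype are invented) of a Relax function that applies a binary operator to a tensor and an `R.Prim` scalar parameter:

# Minimal sketch of the capability this commit adds; assumes a TVM build
# that includes commit 579125e. Names, shapes, and dtypes are illustrative.
from tvm.script import ir_module
from tvm.script import relax as R


@ir_module
class ScaleModule:
    @R.function
    def main(
        x: R.Tensor((16,), "float32"), scale: R.Prim("float32")
    ) -> R.Tensor((16,), "float32"):
        # Binary op between a Tensor and a PrimValue: struct info inference
        # treats `scale` like a 0-d tensor operand.
        return R.multiply(x, scale)

Under this commit, `R.multiply(x, scale)` infers the same `R.Tensor((16,), "float32")` output it would with a 0-d tensor operand, and legalization passes `scale` to the generated PrimFunc as a scalar TIR parameter.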

File tree

7 files changed (+815, -113 lines)

python/tvm/relax/utils.py

Lines changed: 57 additions & 40 deletions
@@ -14,13 +14,19 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 # pylint: disable=invalid-name,too-many-locals
+
 """Utility functions for Relax"""
+
 import functools
 import inspect
+import itertools
+
 from typing import Tuple as typing_Tuple
 from typing import Any, Callable, List, Dict, Optional, TypeVar
 
+import tvm
 from .. import tir
 from ..tir import PrimExpr
 from ..runtime import String, convert_to_object
@@ -302,9 +308,22 @@ def gen_call_tir_inputs(
         out_sinfo, and tir_vars.
     """
 
-    def _convert_te_arg(
-        te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
-    ) -> typing_Tuple[Any, List[te_Tensor]]:
+    tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
+
+    call_tir_args = []
+    # extra list of tir expression arguments
+    # that are not covered by Tensor
+    extra_tir_args_list = []
+
+    def _copy_undefined_var(expr: tir.PrimExpr):
+        def _visit_expr(e: tir.PrimExpr):
+            if isinstance(e, tir.Var) and e not in tir_var_map:
+                new_var = tir.Var(e.name, e.dtype)
+                tir_var_map[e] = new_var
+
+        tir.stmt_functor.post_order_visit(expr, _visit_expr)
+
+    def _convert_te_arg(te_args: Any) -> Any:
         """Helper function used to convert Relax expressions to TE tensor.
 
@@ -335,18 +354,6 @@ def _convert_te_arg(
         A tuple of the converted te_args, and a list of te tensors for each converted
         Relax expression
         """
-        te_args_list = []
-        # extra list of tir expression arguments
-        # that are not covered by Tensor
-        extra_tir_args_list = []
-
-        def _copy_undefined_var(expr: tir.PrimExpr):
-            def _visit_expr(e: tir.PrimExpr):
-                if isinstance(e, tir.Var) and e not in tir_var_map:
-                    new_var = tir.Var(e.name, e.dtype)
-                    tir_var_map[e] = new_var
-
-            tir.stmt_functor.post_order_visit(expr, _visit_expr)
 
         n_tensor = 0
 
@@ -363,18 +370,23 @@ def _convert_te_arg_helper(arg):
                     name = chr(ord("A") + n_tensor) if n_tensor < 26 else f"input{n_tensor}"
                     arg = te_tensor(arg, tir_var_map, name)
                     n_tensor += 1
-                    te_args_list.append(arg)
+                    call_tir_args.append(arg)
                     return arg
                 if isinstance(arg.struct_info, ShapeStructInfo):
                     assert isinstance(
                         arg, ShapeExpr
                     ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
                     return [_convert_te_arg_helper(val) for val in arg.values]
-                if (
-                    isinstance(arg.struct_info, PrimStructInfo)
-                    and arg.struct_info.value is not None
-                ):
-                    return _convert_te_arg_helper(arg.struct_info.value)
+                if isinstance(arg.struct_info, PrimStructInfo):
+                    if arg.struct_info.value is None:
+                        name = arg.name_hint if isinstance(arg, tvm.relax.Var) else "prim_arg"
+                        call_tir_args.append(arg)
+                        return tir.Var(name, arg.struct_info.dtype)
+                        # call_tir_args.append(tir.Var(name, arg.struct_info.dtype))
+                        # return arg
+                    else:
+                        return _convert_te_arg_helper(arg.struct_info.value)
+
             elif isinstance(arg, (list, Array)):
                 return [_convert_te_arg_helper(x) for x in arg]
             elif isinstance(arg, tuple):
@@ -388,35 +400,43 @@ def _convert_te_arg_helper(arg):
             elif isinstance(arg, tir.PrimExpr):
                 _copy_undefined_var(arg)
                 new_arg = tir.stmt_functor.substitute(arg, tir_var_map)
-                extra_tir_args_list.append(new_arg)
+                extra_tir_args_list.append(arg)
                 return new_arg
             elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is None:
                 return arg
             raise TypeError("not supported type in emit_te: {}".format(type(arg)))
 
         new_arg = _convert_te_arg_helper(te_args)
-        return new_arg, te_args_list, extra_tir_args_list
+        return new_arg
 
     def _get_unbound_tir_vars(
         args: List[te_Tensor], extra_tir_args: List[PrimExpr]
     ) -> List[tir.Var]:
         """get unbound TIR vars (i.e TIR vars used in the shape but is not
         itself a dimension of a shape)"""
+
         bound_vars = set()
         used_vars = set()
 
+        def _populate_bound_vars(expr):
+            if isinstance(expr, te_Tensor):
+                for dim in expr.shape:
+                    _populate_bound_vars(dim)
+            elif isinstance(expr, tir.Var):
+                bound_vars.add(expr)
+
         def _populate_used_vars(expr):
-            if isinstance(expr, tir.Var):
-                used_vars.add(expr)
+            if isinstance(expr, te_Tensor):
+                for dim in expr.shape:
+                    _populate_used_vars(dim)
+            elif isinstance(expr, tir.PrimExpr):
+                used_vars.update(tir.analysis.undefined_vars(expr))
 
-        for val in extra_tir_args:
-            tir.stmt_functor.post_order_visit(val, _populate_used_vars)
+        for arg in itertools.chain(args, extra_tir_args):
+            _populate_used_vars(arg)
 
-        for x in args:
-            for s in x.shape:
-                tir.stmt_functor.post_order_visit(s, _populate_used_vars)
-                if isinstance(s, tir.Var):
-                    bound_vars.add(s)
+        for arg in args:
+            _populate_bound_vars(arg)
 
         diff = used_vars - bound_vars
         return list(diff)
@@ -448,19 +468,16 @@ def _shape_with_old_tir_var(
 
     primfunc_attrs = kwargs.pop("primfunc_attrs", None)
 
-    tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
-    new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map)
-    new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs, tir_var_map)
-
-    te_args = te_arg_list + te_kwarg_list
+    te_args = _convert_te_arg(args)
+    te_kwargs = _convert_te_arg(kwargs)
 
-    te_out = func(*new_args, **new_kwargs)
+    te_out = func(*te_args, **te_kwargs)
     assert isinstance(te_out, te_Tensor) or (
         isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, te_Tensor) for t in te_out)
     ), "only support te.tensor or tuple/list/Array of te.tensor as function output"
 
     outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
-    unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + tir_kwarg_list)
+    unbound_tir_vars = _get_unbound_tir_vars([*te_args, *outs], extra_tir_args_list)
 
     inputs = [*te_args] + outs + unbound_tir_vars
     tir_func = create_prim_func(inputs, "int64")
@@ -470,7 +487,7 @@ def _shape_with_old_tir_var(
 
     tir_func = tir_func.without_attr("global_symbol")
 
-    call_tir_args = [x.op.value for x in te_args]
+    call_tir_args = [arg.op.value if isinstance(arg, te_Tensor) else arg for arg in call_tir_args]
 
     # Invert the TIR variable mapping, to convert the output shape back
     # with old set of variables.
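
The rewritten `_get_unbound_tir_vars` treats a TIR variable as bound when it appears directly as a shape dimension, and as used when it appears anywhere in a tensor shape or an extra scalar argument (via `tir.analysis.undefined_vars`). A standalone sketch of that bound/used distinction, with invented variable names:

# Standalone sketch of the bound/used-variable logic in the rewritten
# _get_unbound_tir_vars; the variable and tensor names here are made up.
import itertools

from tvm import te, tir

n = tir.Var("n", "int64")
m = tir.Var("m", "int64")

# `n` appears as a shape dimension of A, so it is bound; `m` appears only
# inside an extra scalar expression, so it remains unbound.
A = te.placeholder((n,), "float32", name="A")
extra_tir_args = [m * 2]

bound_vars = {dim for dim in A.shape if isinstance(dim, tir.Var)}
used_vars = set()
for expr in itertools.chain(A.shape, extra_tir_args):
    used_vars.update(tir.analysis.undefined_vars(expr))

print(used_vars - bound_vars)  # expected: {m}

Here `n` is bound by the shape of `A`, while `m` is only used inside a scalar expression, so it is reported as unbound and becomes an extra parameter of the generated PrimFunc.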

src/relax/op/op_common.h

Lines changed: 60 additions & 21 deletions
@@ -239,52 +239,91 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
                                         const Map<String, Array<String>>& desired_layouts,
                                         const VarLayoutMap& var_layout_map);
 
+/*!
+ * \brief Get the element dtype from StructInfo
+ *
+ * \param sinfo The StructInfo to expect
+ * \return The inferred element dtype.
+ * \throw Throw exception if the StructInfo doesn't have an element type.
+ */
+inline DataType GetElementDType(const StructInfo& sinfo) {
+  if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
+    return prim->dtype;
+  } else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+    return tensor->dtype;
+  } else if (sinfo.as<ObjectStructInfoNode>()) {
+    return DataType::Void();
+  } else {
+    LOG(FATAL) << "TypeError: "
+               << "Cannot determine element type of " << sinfo;
+  }
+}
+
 /*!
  * \brief Infer the output datatype for binary arithmetic operators.
  * \param call The context Call to the operator.
  * \param ctx The error reporting context.
- * \param x1_sinfo The struct info of the first operand
- * \param x2_sinfo The struct info of the second operand
+ * \param lhs_sinfo The struct info of the first operand
+ * \param rhs_sinfo The struct info of the second operand
  * \return The inferred output dtype.
  * \throw Throw exception if the dtype of two input TensorStructInfo don't match
  */
 inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx,
-                                           const TensorStructInfo& x1_sinfo,
-                                           const TensorStructInfo& x2_sinfo) {
-  if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) {
+                                           const StructInfo& lhs_sinfo,
+                                           const StructInfo& rhs_sinfo) {
+  auto lhs_dtype = GetElementDType(lhs_sinfo);
+  auto rhs_dtype = GetElementDType(rhs_sinfo);
+  if (lhs_dtype.is_void() || rhs_dtype.is_void()) {
     return DataType::Void();
-  } else if (x1_sinfo->dtype != x2_sinfo->dtype) {
+  } else if (lhs_dtype != rhs_dtype) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype
-                     << " must be equal for binary operators");
+                     << "TypeError: "
+                     << "Binary operators must have the same datatype for both operands. "
+                     << "However, " << call << " uses datatype " << lhs_dtype
+                     << " on the LHS (StructInfo of " << lhs_sinfo << "), and datatype "
+                     << rhs_dtype << " on the RHS (StructInfo of " << rhs_sinfo << ").");
   }
-  return x1_sinfo->dtype;
+  return lhs_dtype;
 }
 
 /*!
  * \brief Infer the output virtual device for binary arithmetic operators.
  * \param call The context Call to the operator.
  * \param ctx The error reporting context.
- * \param x1_sinfo The struct info of the first operand
- * \param x2_sinfo The struct info of the second operand
+ * \param lhs_sinfo The struct info of the first operand
+ * \param rhs_sinfo The struct info of the second operand
  * \return The inferred output vdevice.
  * \throw Throw exception if the vdevice of two input TensorStructInfo don't match
  */
 inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx,
-                                                      const TensorStructInfo& x1_sinfo,
-                                                      const TensorStructInfo& x2_sinfo) {
-  if (!x1_sinfo->vdevice.defined() || !x1_sinfo->vdevice.value()->target.defined()) {
-    return x2_sinfo->vdevice;
+                                                      const StructInfo& lhs_sinfo,
+                                                      const StructInfo& rhs_sinfo) {
+  auto get_vdevice = [&](const StructInfo& sinfo) -> Optional<VDevice> {
+    if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+      return tensor->vdevice;
+    } else {
+      return NullOpt;
+    }
+  };
+
+  auto lhs_vdevice = get_vdevice(lhs_sinfo);
+  auto rhs_vdevice = get_vdevice(rhs_sinfo);
+
+  if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) {
+    return rhs_vdevice;
   }
-  if (!x2_sinfo->vdevice.defined() || !x2_sinfo->vdevice.value()->target.defined()) {
-    return x1_sinfo->vdevice;
+  if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) {
+    return lhs_vdevice;
   }
-  if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) {
+  if (lhs_vdevice.value() != rhs_vdevice.value()) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "VDevice " << x1_sinfo->vdevice.value() << " and "
-                     << x2_sinfo->vdevice.value() << " must be equal for binary operators");
+                     << "TypeError: "
+                     << "Binary operators with Tensor arguments "
+                     << "must have the same VDevice for both operands. "
+                     << "However, " << call << " has a LHS on VDevice " << lhs_vdevice
+                     << " and a RHS on VDevice " << rhs_vdevice);
   }
-  return x1_sinfo->vdevice;
+  return lhs_vdevice;
 }
 
 /*!
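
With `GetElementDType`, a `PrimValue` operand participates in the same dtype-agreement check as a tensor operand. A hedged sketch of how a mismatch would surface from Python (variable names are invented, and this assumes a TVM build that includes this commit):

# Sketch of the dtype-agreement rule enforced by InferBinaryArithOpOutDtype;
# names are illustrative only.
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo([8], "float32"))
s = relax.Var("s", relax.PrimStructInfo("int64"))

bb = relax.BlockBuilder()
try:
    with bb.function("f", [x, s]):
        # float32 Tensor vs. int64 PrimValue: struct info inference should
        # report the mismatch via the diagnostic added above.
        y = bb.emit(relax.op.add(x, s))
        bb.emit_func_output(y)
except Exception as err:  # DiagnosticError/TVMError, depending on TVM version
    print("dtype mismatch rejected:", type(err).__name__)

With matching dtypes, `InferBinaryArithOpOutDtype` would instead infer a `float32` output, exactly as for two tensor operands.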
