20 changes: 13 additions & 7 deletions python/tvm/relax/op/base.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# pylint: disable=redefined-builtin
"""The base Relax operators."""

from typing import Dict, Union, List, Tuple, Optional, Callable


@@ -71,11 +72,10 @@ def null_value() -> Call:
def call_tir(
gvar: GlobalVar,
args: Expr,
out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
out_sinfo: Optional[Union[TensorStructInfo, List[TensorStructInfo]]] = None,
tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None,
) -> Call:
"""
Call a tir.prim_func and return the output.
"""Call a tir.PrimFunc and return the output.

Parameters
----------
@@ -85,23 +85,29 @@ def call_tir(
args : Expr
The input arguments.

out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
The structure info of the call_tir output.
It should be a single or a list of TensorStructInfo. Each one denotes the
out_sinfo : Optional[Union[TensorStructInfo, List[TensorStructInfo]]]
The structure info of the call_tir output. It should be a
single or a list of TensorStructInfo. Each one denotes the
structure info of a returned tensor.

If `None`, the `out_sinfo` will be inferred from the signature
of `gvar`. Arguments that are accepted by `gvar`, after
`args` and before `tir_vars`, are inferred to be output tensor
arguments.

tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]]
ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used

Returns
-------
ret: Call
A call node for the call_tir operator.

"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))

if not isinstance(out_sinfo, list):
if out_sinfo is not None and not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]

if isinstance(tir_vars, (list, tuple)):
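For illustration, a minimal sketch of the intended usage with the new optional `out_sinfo` (the module, function name, and shapes below are hypothetical, not taken from this PR):

```python
from tvm.script import ir as I, relax as R, tir as T

@I.ir_module
class Module:
    @T.prim_func
    def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
        # One input parameter (A) followed by one output parameter (B).
        for i in range(16):
            with T.block("compute"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + T.float32(1)

    @R.function
    def main(x: R.Tensor((16,), "float32")) -> R.Tensor((16,), "float32"):
        cls = Module
        # Previously, out_sinfo was required:
        #   y = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((16,), "float32"))
        # With this change it may be omitted; the parameters of add_one that
        # come after the provided args (here, just B) are inferred to be the
        # output tensors.
        y = R.call_tir(cls.add_one, (x,))
        return y
```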
138 changes: 137 additions & 1 deletion src/relax/op/op.cc
@@ -331,8 +331,144 @@ RELAY_REGISTER_OP("relax.call_tir")
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
static Array<TensorStructInfo> InferCallTIROutputStructInfo(Expr func, Tuple args,
Member: One thing to note is that it is not always possible to do such inference, since it is possible to have TIR functions like reshape, where the output shape is explicitly specified via the destination. For the particular low-level call_tir op, I think it is safer to always ask for the sinfo, then explicitly check the consistency to avoid errors.

Contributor Author: Absolutely agreed that we should check for consistency after generating the IR, and that's something I want to add to the well-formed checker as well. This specific PR would be to avoid inconsistency while generating the IR.

(And if we can't infer the output shape, then the output shape must still be explicitly provided.)

Optional<Expr> packed_ints) {
auto opt_callee_sinfo = func->struct_info_.as<FuncStructInfo>();
CHECK(opt_callee_sinfo) << "ValueError: "
<< "If the `out_sinfo` argument to `R.call_tir` is omitted, "
<< "then the callee must be annotated with FuncStructInfo, "
<< "from which the return tensor shapes will be inferred.";
auto callee_sinfo = opt_callee_sinfo.value();

CHECK(callee_sinfo->params.defined())
<< "ValueError: "
<< "If the `out_sinfo` argument to `R.call_tir` is omitted, "
<< "then the callee's FuncStructInfo must have known signature, "
<< "but the callee instead has StructInfo " << callee_sinfo;
auto callee_params = callee_sinfo->params.value();

// R.call_tir expects the PrimFunc to have three groups of arguments.
//
// 1. Input arguments that are explicitly provided as Relax arguments.
// 2. Output tensor arguments.
// 3. Shape arguments, represented as `T.int64` in the PrimFunc, and
// as an optional ShapeExpr argument in the `relax::Call` node.
//
// In order to determine the return type of `R.call_tir`, we must
// identify the PrimFunc arguments that will be in group (2).
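//
// For example (a hypothetical signature, not taken from this diff), a
// PrimFunc `concat(A, B, Out, n)` with a trailing `T.int64` parameter `n`,
// called as `R.call_tir(concat, (a, b), tir_vars=R.shape([n]))`, would be
// grouped as inputs = {A, B}, outputs = {Out}, and shape arguments = {n}.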
size_t num_input_arguments = args->fields.size();
size_t num_trailing_int_arguments = 0;
const ShapeStructInfoNode* packed_tuple_sinfo = nullptr;
if (packed_ints) {
auto packed_sinfo = packed_ints.value()->struct_info_;
packed_tuple_sinfo = packed_sinfo.as<ShapeStructInfoNode>();
CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim())
<< "ValueError: "
<< "If the `out_sinfo` argument to `R.call_tir` is omitted, "
<< "and the `tir_vars` argument is present, "
<< "then it must be annotated with ShapeStructInfo "
<< "using a known number of dimensions. "
<< "However, struct info " << packed_sinfo
<< " does not have a known number of dimensions.";
num_trailing_int_arguments = packed_tuple_sinfo->ndim;
} else {
num_trailing_int_arguments = 0;
}

CHECK_LE(num_input_arguments + num_trailing_int_arguments, callee_params.size())
<< "ValueError: "
<< "R.call_tir attempted to call " << func << " using " << num_input_arguments
<< " input arguments and " << num_trailing_int_arguments << " trailing integer arguments. "
<< "However, the callee only accepts " << callee_params.size() << " arguments in total.";

// At this point, the return types are known. However, the shapes
// in `callee_params` may contain dynamic shape parameters that are
// not present in the caller's scope. The `DeriveCallRetStructInfo`
// utility can infer the value of dynamic parameters in
// `FuncStructInfoNode::ret` based on definitions in
// `FuncStructInfoNode::params`, inferring the correct values in the
// caller's scope.
//
// Since the callee of `R.call_tir` is provided with output
// arguments, where `DeriveCallRetStructInfo` requires a callee that
// produces its own outputs, a dummy function signature and
// arguments are used.
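//
// Continuing the hypothetical `concat(A, B, Out, n)` example above, the
// dummy signature would be (A, B, n) -> Tuple(Out), allowing
// DeriveCallRetStructInfo to resolve Out's shape in the caller's scope.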

auto dummy_callee_sinfo = [&]() -> FuncStructInfo {
Array<StructInfo> dummy_params(callee_params.begin(),
callee_params.begin() + num_input_arguments);

for (size_t i = callee_params.size() - num_trailing_int_arguments; i < callee_params.size();
i++) {
dummy_params.push_back(callee_params[i]);
}

Array<StructInfo> dummy_ret(callee_params.begin() + num_input_arguments,
callee_params.end() - num_trailing_int_arguments);

return FuncStructInfo(dummy_params, TupleStructInfo(dummy_ret));
}();

auto dummy_args = [&]() -> Array<Expr> {
Array<Expr> dummy_args = args->fields;

for (size_t i = 0; i < num_trailing_int_arguments; i++) {
ICHECK(packed_tuple_sinfo);
PrimStructInfo dummy_arg_sinfo = [&]() {
if (packed_tuple_sinfo->values) {
return PrimStructInfo(packed_tuple_sinfo->values.value()[i]);
} else {
return PrimStructInfo(DataType::Int(64));
}
}();
dummy_args.push_back(Var("dummy_arg", dummy_arg_sinfo));
}

return dummy_args;
}();

auto derived_ret_sinfo = DeriveCallRetStructInfo(
dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo), dummy_args),
BlockBuilder::Create(NullOpt));

Array<TensorStructInfo> out_sinfo_list;
auto out_fields = Downcast<TupleStructInfo>(derived_ret_sinfo)->fields;
for (size_t i = 0; i < out_fields.size(); i++) {
auto field_sinfo = out_fields[i];
auto opt_tensor_sinfo = field_sinfo.as<TensorStructInfo>();
CHECK(opt_tensor_sinfo)
<< "TypeError: "
<< "If the `out_sinfo` argument to `R.call_tir` is omitted, "
<< "output tensor arguments are inferred from the number of input arguments. "
<< "However, output " << i << " (corresponding to function parameter "
<< (i + num_input_arguments) << ") has struct info " << field_sinfo
<< ", and is not a tensor.";
auto tensor_sinfo = opt_tensor_sinfo.value();

CHECK(tensor_sinfo->shape.defined())
<< "If the `out_sinfo` argument to `R.call_tir` is omitted, "
<< "output tensor arguments are inferred from the number of input arguments. "
<< "If a TIR function determines a dynamic shape parameter "
<< "based on the shape of an output tensor, "
<< "then this shape inference is impossible. "
<< "Since TIR functions accept dynamic TIR variables may be define"
<< "Please update the `R.call_tir` call with an explicit `out_sinfo` argument.";

out_sinfo_list.push_back(tensor_sinfo);
}
return out_sinfo_list;
}
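
As the review discussion above notes, this inference is not always possible. A hypothetical sketch (names and shapes are illustrative, not from this PR) of a reshape-like PrimFunc whose output shape is carried only by its destination buffer:

```python
from tvm.script import tir as T

@T.prim_func
def reshape_like(a: T.handle, b: T.handle):
    n = T.int64()
    m = T.int64()
    k = T.int64()
    A = T.match_buffer(a, (n,), "float32")
    # The output shape (m, k) is known only from the destination buffer and
    # cannot be derived from the input shape (n,) alone.
    B = T.match_buffer(b, (m, k), "float32")
    for i, j in T.grid(m, k):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi * k + vj]
```

For a callee like this, the caller must still pass `out_sinfo` explicitly; the inference added here is a convenience for the common case where every output shape is determined by the inputs and `tir_vars`.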

Expr MakeCallTIR(Expr func, Tuple args, Optional<Array<TensorStructInfo>> opt_out_sinfo_list,
Optional<Expr> packed_ints) {
auto out_sinfo_list = [&]() -> Array<TensorStructInfo> {
if (opt_out_sinfo_list) {
return opt_out_sinfo_list.value();
} else {
return InferCallTIROutputStructInfo(func, args, packed_ints);
}
}();

for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. "