diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 720e3dd3b429..fbca48f0ee5e 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1629,6 +1629,75 @@ def tensor_ir_op(
     )
 
 
+def tensor_ir_inplace_op(
+    func: _tir.PrimFunc,
+    name_hint: str,
+    args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
+    inplace_indices: Union[int, List[int]],
+    out: OutType,
+) -> OutType:
+    """Create a `call_tir_inplace` binding with given PrimFunc
+
+    Parameters
+    ----------
+    func : _tir.PrimFunc
+        The PrimFunc to call.
+
+    name_hint : str
+        Name hint.
+
+    args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
+        The arguments to pass to the PrimFunc.
+
+    inplace_indices : Union[int, List[int]]
+        Specify which arguments should be used for in-place computations.
+        If `inplace_indices` is a single integer, it will be made into a singleton list.
+        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
+        will be an alias of `args[j]`.
+        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+        At least one member of `inplace_indices` must not be -1.
+
+    out : Union[Tensor, List[Tensor]]
+        The output tensors.
+
+    Returns
+    -------
+    result : Tensor
+        The result tensor
+    """
+    from tvm import relax as rx  # pylint: disable=import-outside-toplevel
+
+    call_tir_args, tir_vars = [], []
+    if not isinstance(args, (tuple, list)):
+        args = [args]
+
+    for arg in args:
+        if isinstance(arg, Tensor):
+            call_tir_args.append(arg._expr)
+        elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
+            tir_vars.append(arg)
+        else:
+            raise TypeError(
+                "Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or"
+                f" PrimExpr, but got {type(arg)}"
+            )
+
+    if isinstance(out, Tensor):
+        out_sinfo = [out._expr.struct_info]
+    else:
+        out_sinfo = [x._expr.struct_info for x in out]
+
+    bb = BlockBuilder.current()
+    global_var = bb.add_func(func, name_hint)
+
+    return wrap_nested(
+        bb.emit(
+            rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_sinfo, tir_vars)
+        ),
+        name=name_hint,
+    )
+
+
 def extern(
     name: str,
     args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index b363dc6952d8..92235ffb479e 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -198,13 +198,13 @@ def call_tir_inplace(
     args : Expr
         The input arguments.
 
-    input_indices : Union[int, List[int]]
+    inplace_indices : Union[int, List[int]]
         Specify which arguments should be used for in-place computations.
-        If `input_indices` is a single integer, it will be made into a singleton list.
-        Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+        If `inplace_indices` is a single integer, it will be made into a singleton list.
+        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
         will be an alias of `args[j]`.
-        If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
-        At least one member of `input_indices` must not be -1.
+        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+        At least one member of `inplace_indices` must not be -1.
 
     out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
         The structure info of the call_tir_inplace output.
@@ -637,13 +637,13 @@ def call_inplace_packed(
     args: Expr
        The arguments for the PackedFunc.
 
-    input_indices : Union[int, List[int]]
+    inplace_indices : Union[int, List[int]]
        Specify which arguments should be used for in-place computations.
-       If `input_indices` is a single integer, it will be made into a singleton list.
-       Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+       If `inplace_indices` is a single integer, it will be made into a singleton list.
+       Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
        will be an alias of `args[j]`.
-       If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
-       At least one member of `input_indices` must not be -1.
+       If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+       At least one member of `inplace_indices` must not be -1.
 
     sinfo_args: Union[StructInfo, List[StructInfo]]
        The list of structure info arguments (giving the structural info for the returned value).
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index ed2e3753b2fb..c74e06490fc5 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -619,6 +619,111 @@ def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offse
     tvm.ir.assert_structural_equal(irmodule, Expected)
 
 
+def test_tensor_ir_inplace_op():
+    hidden_size = 4096
+    dtype = "float16"
+
+    @T.prim_func
+    def inplace_take(
+        var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        vocab_size = T.int64()
+        weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
+        seq_len = T.int64()
+        total_seq_len = T.int64()
+        pos = T.match_buffer(var_pos, (seq_len,), "int32")
+        embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
+        for ax0, ax1 in T.grid(seq_len, hidden_size):
+            with T.block("T_take"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(weight[pos[v0], v1], pos[v0])
+                T.writes(embeddings[v0, v1])
+                embeddings[v0 + offset, v1] = weight[pos[v0], v1]
+
+    class Model(Module):
+        def test(
+            self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int
+        ):
+            tensor_expr_op_out = op.tensor_ir_op(
+                inplace_take,
+                "inplace_take",
+                args=[embedding_table, input_ids, embedding_dst, offset],
+                out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype),
+            )
+            return tensor_expr_op_out
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def inplace_take(
+            var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            vocab_size = T.int64()
+            weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
+            seq_len = T.int64()
+            total_seq_len = T.int64()
+            pos = T.match_buffer(var_pos, (seq_len,), "int32")
+            embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
+            for ax0, ax1 in T.grid(seq_len, hidden_size):
+                with T.block("T_take"):
+                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(weight[pos[v0], v1], pos[v0])
+                    T.writes(embeddings[v0, v1])
+                    embeddings[v0 + offset, v1] = weight[pos[v0], v1]
+
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io: R.Object = R.null_value()
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
+                R.output(gv)
+            return gv
+
+        @R.function
+        def test(
+            embedding_table: R.Tensor(("vocab_size", hidden_size), dtype),
+            input_ids: R.Tensor(("seq_len",), "int32"),
+            embedding_dst: R.Tensor(("total_seq_len", hidden_size), dtype),
+            offset: R.Shape(["offset_1"]),
+            packed_params: R.Tuple,
+        ) -> R.Tensor(("total_seq_len", hidden_size), dtype):
+            total_seq_len = T.int64()
+            offset_1 = T.int64()
+            R.func_attr({"num_input": 4})
+            cls = Expected
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.inplace_take,
+                    (embedding_table, input_ids, embedding_dst),
+                    out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
+                    tir_vars=R.shape([offset_1]),
+                )
+                gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
+                R.output(gv1)
+            return gv1
+
+    m = Model()
+    irmodule, _ = m.export_tvm(
+        spec={
+            "test": {
+                "embedding_table": spec.Tensor(["vocab_size", hidden_size], dtype),
+                "input_ids": spec.Tensor(["seq_len"], "int32"),
+                "embedding_dst": spec.Tensor(["total_seq_len", hidden_size], dtype),
+                "offset": int,
+                "$": {
+                    "param_mode": "packed",
+                    "effect_mode": "none",
+                },
+            },
+        },
+        debug=True,
+    )
+    tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
 def test_extern():
     class Model(Module):
         def test(self, q: Tensor, k: Tensor, v: Tensor):