Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,75 @@ def tensor_ir_op(
)


def tensor_ir_inplace_op(
    func: _tir.PrimFunc,
    name_hint: str,
    args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
    inplace_indices: Union[int, List[int]],
    out: OutType,
) -> OutType:
    """Emit a `call_tir_inplace` binding that invokes the given PrimFunc.

    Parameters
    ----------
    func : _tir.PrimFunc
        The PrimFunc to call.

    name_hint : str
        Name hint.

    args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
        The arguments to pass to the PrimFunc. Tensor arguments are passed
        as call arguments; ShapeExpr/PrimExpr arguments travel as TIR vars.

    inplace_indices : Union[int, List[int]]
        Specify which arguments should be used for in-place computations.
        If `inplace_indices` is a single integer, it will be made into a singleton list.
        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
        will be an alias of `args[j]`.
        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
        At least one member of `inplace_indices` must not be -1.

    out : Union[Tensor, List[Tensor]]
        The output tensors.

    Returns
    -------
    result : Tensor
        The result tensor
    """
    from tvm import relax as rx  # pylint: disable=import-outside-toplevel

    # Normalize a lone argument into a one-element sequence.
    arg_seq = args if isinstance(args, (tuple, list)) else [args]

    # Partition the arguments: Tensors become call_tir arguments, while
    # shape/PrimExpr values are forwarded through `tir_vars`.
    tensor_exprs, shape_exprs = [], []
    for item in arg_seq:
        if isinstance(item, Tensor):
            tensor_exprs.append(item._expr)
        elif isinstance(item, (rx.ShapeExpr, _tir.PrimExpr)):
            shape_exprs.append(item)
        else:
            raise TypeError(
                "Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or"
                f" PrimExpr, but got {type(item)}"
            )

    # Collect the struct info of every requested output placeholder.
    out_tensors = [out] if isinstance(out, Tensor) else out
    out_sinfo = [tensor._expr.struct_info for tensor in out_tensors]

    builder = BlockBuilder.current()
    gvar = builder.add_func(func, name_hint)
    call_expr = rx.call_tir_inplace(gvar, tensor_exprs, inplace_indices, out_sinfo, shape_exprs)
    return wrap_nested(builder.emit(call_expr), name=name_hint)


def extern(
name: str,
args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
Expand Down
20 changes: 10 additions & 10 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ def call_tir_inplace(
args : Expr
The input arguments.

input_indices : Union[int, List[int]]
inplace_indices : Union[int, List[int]]
Specify which arguments should be used for in-place computations.
If `input_indices` is a single integer, it will be made into a singleton list.
Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
If `inplace_indices` is a single integer, it will be made into a singleton list.
Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
will be an alias of `args[j]`.
If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
At least one member of `input_indices` must not be -1.
If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
At least one member of `inplace_indices` must not be -1.

out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
The structure info of the call_tir_inplace output.
Expand Down Expand Up @@ -637,13 +637,13 @@ def call_inplace_packed(
args: Expr
The arguments for the PackedFunc.

input_indices : Union[int, List[int]]
inplace_indices : Union[int, List[int]]
Specify which arguments should be used for in-place computations.
If `input_indices` is a single integer, it will be made into a singleton list.
Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
If `inplace_indices` is a single integer, it will be made into a singleton list.
Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
will be an alias of `args[j]`.
If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
At least one member of `input_indices` must not be -1.
If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
At least one member of `inplace_indices` must not be -1.

sinfo_args: Union[StructInfo, List[StructInfo]]
The list of structure info arguments (giving the structural info for the returned value).
Expand Down
105 changes: 105 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,111 @@ def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offse
tvm.ir.assert_structural_equal(irmodule, Expected)


def test_tensor_ir_inplace_op():
    """Check that `nn.op.tensor_ir_inplace_op` lowers to `R.call_tir_inplace`.

    Fixes from review: the model previously called `op.tensor_ir_op`, so the
    new in-place op was never exercised. The model now calls
    `op.tensor_ir_inplace_op` with `inplace_indices=2` (output 0 aliases
    `embedding_dst`), and the expected IR uses `R.call_tir_inplace`.
    """
    hidden_size = 4096
    dtype = "float16"

    @T.prim_func
    def inplace_take(
        var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        vocab_size = T.int64()
        weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
        seq_len = T.int64()
        total_seq_len = T.int64()
        pos = T.match_buffer(var_pos, (seq_len,), "int32")
        embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
        for ax0, ax1 in T.grid(seq_len, hidden_size):
            with T.block("T_take"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(weight[pos[v0], v1], pos[v0])
                # The declared write region must cover the actual store,
                # which lands at row `v0 + offset`, not row `v0`.
                T.writes(embeddings[v0 + offset, v1])
                embeddings[v0 + offset, v1] = weight[pos[v0], v1]

    class Model(Module):
        def test(
            self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int
        ):
            # Output 0 aliases args[2] (embedding_dst): the take is performed
            # in place into the destination buffer.
            tensor_expr_op_out = op.tensor_ir_inplace_op(
                inplace_take,
                "inplace_take",
                args=[embedding_table, input_ids, embedding_dst, offset],
                inplace_indices=2,
                out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype),
            )
            return tensor_expr_op_out

    @I.ir_module
    class Expected:
        @T.prim_func
        def inplace_take(
            var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            vocab_size = T.int64()
            weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
            seq_len = T.int64()
            total_seq_len = T.int64()
            pos = T.match_buffer(var_pos, (seq_len,), "int32")
            embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
            for ax0, ax1 in T.grid(seq_len, hidden_size):
                with T.block("T_take"):
                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(weight[pos[v0], v1], pos[v0])
                    # Keep in sync with the PrimFunc above for structural equality.
                    T.writes(embeddings[v0 + offset, v1])
                    embeddings[v0 + offset, v1] = weight[pos[v0], v1]

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(
            embedding_table: R.Tensor(("vocab_size", hidden_size), dtype),
            input_ids: R.Tensor(("seq_len",), "int32"),
            embedding_dst: R.Tensor(("total_seq_len", hidden_size), dtype),
            offset: R.Shape(["offset_1"]),
            packed_params: R.Tuple,
        ) -> R.Tensor(("total_seq_len", hidden_size), dtype):
            total_seq_len = T.int64()
            offset_1 = T.int64()
            R.func_attr({"num_input": 4})
            cls = Expected
            with R.dataflow():
                # In-place call: output 0 aliases the third argument.
                lv1 = R.call_tir_inplace(
                    cls.inplace_take,
                    (embedding_table, input_ids, embedding_dst),
                    out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
                    inplace_indices=[2],
                    tir_vars=R.shape([offset_1]),
                )
                gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
                R.output(gv1)
            return gv1

    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "embedding_table": spec.Tensor(["vocab_size", hidden_size], dtype),
                "input_ids": spec.Tensor(["seq_len"], "int32"),
                "embedding_dst": spec.Tensor(["total_seq_len", hidden_size], dtype),
                "offset": int,
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)


def test_extern():
class Model(Module):
def test(self, q: Tensor, k: Tensor, v: Tensor):
Expand Down