6 changes: 5 additions & 1 deletion src/relax/ir/block_builder.cc
@@ -170,13 +170,17 @@ class BlockBuilderImpl : public BlockBuilderNode {
auto it = shape_var_map.find(shape_var);
if (it == shape_var_map.end()) {
shape_var_map.Set(shape_var, shape_expr);
// Expose the shape variable as non-negative, for purposes
// of shape inference. In many cases, knowing that the
// shape variable is non-negative allows for simpler
// expressions for dynamic shapes.
analyzer_.MarkGlobalNonNegValue(shape_var);
Contributor: Might be worth commenting that indicating this results in much simpler symbolic shapes in many cases.

Contributor (Author): Good call, and I've added a comment on it.

} else {
const PrimExpr& old_shape_expr = (*it).second;
CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr))
<< "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs "
<< shape_expr;
}
shape_var_map.Set(shape_var, shape_expr);
}
}
scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)}));
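As a side note on the review thread above: a minimal sketch (not part of this change) of why declaring a shape variable non-negative matters for shape inference. It assumes the standard `tvm.arith` Python bindings; since there may be no direct Python binding for `MarkGlobalNonNegValue`, the sketch approximates the scope's guarantee with a `ConstIntBound` update, and the names and the concrete expression are illustrative only.

```python
# Minimal sketch: the ConstIntBound update stands in for the non-negativity
# that BeginScope now establishes for shape variables on the C++ side.
import tvm
from tvm import tir, arith

n = tir.Var("n", "int64")
analyzer = arith.Analyzer()

# Extent of the middle piece of split(indices=[2, 5]) on an axis of
# length n + 5, with both split points clamped into bounds.
extent = tir.max(tir.min(5, n + 5) - tir.min(2, n + 5), 0)

# Without any bound on n, the min/max terms cannot be removed.
print(analyzer.simplify(extent))

# Declare n >= 0, mirroring what the scope now does for shape variables.
analyzer.update(n, arith.ConstIntBound(0, arith.ConstIntBound.POS_INF))

# With n >= 0 known, the extent collapses to the constant 3, matching
# the updated test expectations.
print(analyzer.simplify(extent))
```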
50 changes: 30 additions & 20 deletions src/relax/op/tensor/manipulate.cc
@@ -846,41 +846,48 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
int axis =
data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);

if (const auto* p_indices = attrs->indices_or_sections.as<ArrayNode>()) {
if (auto opt_indices = attrs->indices_or_sections.as<Array<IntImm>>()) {
auto p_indices = opt_indices.value();
// When there are no indices, return the input tensor's struct info.
if (p_indices->size() == 0) {
if (p_indices.size() == 0) {
return TupleStructInfo({data_sinfo});
}
// Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
if (data_shape == nullptr) {
return TupleStructInfo(Array<StructInfo>(
p_indices->size() + 1,
p_indices.size() + 1,
TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
}

ICHECK_NE(axis, -1);
const auto* axis_length = data_shape->values[axis].as<IntImmNode>();
// Fall back to unknown shape when the input tensor shape at the given axis is symbolic.
if (axis_length == nullptr) {
return TupleStructInfo(Array<StructInfo>(
p_indices->size() + 1,
TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
}

// Only do output shape inference when all the indices and the total length are integers.
Array<IntImm> indices = GetRef<Array<IntImm>>(p_indices);
IntImm zero(DataType::Int(64), /*value=*/0);
indices.insert(indices.begin(), zero);
indices.insert(indices.end(), Downcast<IntImm>(data_shape->values[axis]));

std::vector<StructInfo> output_sinfo;
output_sinfo.reserve(indices.size() - 1);
for (int i = 0; i + 1 < static_cast<int>(indices.size()); ++i) {
PrimExpr l = tvm::max(zero, indices[i]);
PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]);
for (size_t i = 0; i < p_indices.size() + 1; i++) {
PrimExpr left;
if (i == 0) {
left = zero;
} else {
left = p_indices[i - 1];
}

PrimExpr right;
if (i < p_indices.size()) {
right = p_indices[i];
} else {
right = data_shape->values[axis];
}

left = tvm::min(tvm::max(left, 0), data_shape->values[axis]);
right = tvm::min(tvm::max(right, 0), data_shape->values[axis]);

PrimExpr split_dim = right - left;
split_dim = tvm::max(split_dim, 0);
split_dim = ctx->GetAnalyzer()->Simplify(split_dim);

Array<PrimExpr> shape = data_shape->values;
shape.Set(axis, tvm::max(zero, r - l));
shape.Set(axis, split_dim);
output_sinfo.push_back(
TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
}
@@ -899,6 +906,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
}
ICHECK_NE(axis, -1);
PrimExpr split_len = ceildiv(data_shape->values[axis], n_section);
split_len = ctx->GetAnalyzer()->Simplify(split_len);

// Construct struct info for tensors except the last one.
Array<PrimExpr> shape = data_shape->values;
@@ -907,7 +915,9 @@
n_section - 1, TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));

// Construct struct info for the last tensor.
shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1));
PrimExpr last_split_len = data_shape->values[axis] - split_len * (n_section - 1);
last_split_len = ctx->GetAnalyzer()->Simplify(last_split_len);
shape.Set(axis, last_split_len);
output_sinfo.push_back(
TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
return TupleStructInfo(output_sinfo);
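For reference, a rough Python transcription (illustrative only, not part of the change) of the per-piece extent computation that the new loop above performs for split-by-indices: each boundary is clamped into `[0, axis_len]`, and each extent is clamped to be non-negative. The helper name and the plain-integer types are assumptions made for the sketch.

```python
# Illustrative transcription of the clamping logic in InferStructInfoSplit;
# plain Python ints stand in for PrimExprs.
def split_extents(axis_len, indices):
    bounds = [0] + list(indices) + [axis_len]
    extents = []
    for left, right in zip(bounds[:-1], bounds[1:]):
        left = min(max(left, 0), axis_len)
        right = min(max(right, 0), axis_len)
        extents.append(max(right - left, 0))
    return extents

# Matches the expected static shapes in the updated tests:
assert split_extents(16, [2, 5]) == [2, 3, 11]
# Out-of-range indices clamp to empty trailing pieces rather than erroring:
assert split_extents(4, [10, 20]) == [4, 0, 0]
```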
125 changes: 118 additions & 7 deletions tests/python/relax/test_op_manipulate.py
@@ -20,7 +20,7 @@
from tvm import relax, tir
from tvm import TVMError
from tvm.ir import Op, VDevice
from tvm.script import relax as R
from tvm.script import relax as R, tir as T


def test_op_correctness():
@@ -1832,9 +1832,9 @@ def test_split_infer_struct_info_by_indices_shape_symbolic():
relax.op.split(x, [10, 20], axis=1),
relax.TupleStructInfo(
[
relax.TensorStructInfo(dtype="float32", ndim=2),
relax.TensorStructInfo(dtype="float32", ndim=2),
relax.TensorStructInfo(dtype="float32", ndim=2),
relax.TensorStructInfo([a, T.max(T.min(10, b) - T.min(0, b), 0)], dtype="float32"),
relax.TensorStructInfo([a, T.max(T.min(20, b) - T.min(10, b), 0)], dtype="float32"),
relax.TensorStructInfo([a, T.max(b - 20, 0)], dtype="float32"),
]
),
)
@@ -1987,9 +1987,9 @@ def test_split_infer_struct_info_by_n_section_shape_symbolic():
relax.op.split(x, 3, axis=1),
relax.TupleStructInfo(
[
relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2), "float32"),
relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
relax.TensorStructInfo((a, b - (b + 2) // 3 * 2), "float32"),
]
),
)
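The switch from `tir.ceildiv(b, 3)` to `(b + 2) // 3` in the expected shapes follows from the new `Simplify` calls in `InferStructInfoSplit`: the analyzer canonicalizes the ceildiv into floordiv form. A small sketch of the same rewrite (illustrative only), mirroring the non-negative bound the scope now provides for shape variables:

```python
import tvm
from tvm import tir, arith

b = tir.Var("b", "int64")
analyzer = arith.Analyzer()
# Mirror the scope's non-negativity guarantee for shape variables.
analyzer.update(b, arith.ConstIntBound(0, arith.ConstIntBound.POS_INF))
# Expected to print floordiv(b + 2, 3), matching the updated expectations above.
print(analyzer.simplify(tir.ceildiv(b, 3)))
```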
@@ -2176,6 +2176,117 @@ def test_split_indices_or_sections_int64():
assert split1.attrs.indices_or_sections.dtype == "int64"


def test_split_infer_struct_info():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((16, 4)))
y = relax.Var("y", R.Tensor((16, 4), "float32"))
z = relax.Var("z", R.Tensor((n, 16)))
w = relax.Var("w", R.Tensor((n + 5, 16)))

# All relax shape variables are non-negative. When a scope
# begins, any TIR variables that are used as shape variables are
# declared to be non-negative in the scope's `tvm.arith.Analyzer`. Because
# `relax.op.split` clamps the indices to be within the bounds of
# the axis being split, simplifying with non-negative shape
# variables can result in much simpler shapes.
#
# For example, an axis of size `n + 5`, split on the range from 2 to 5,
# has size `T.max(T.min(5, n + 5) - T.min(2, n + 5), 0)`. If it
# is known that `n >= 0`, then this simplifies down to `3`.
bb.begin_scope([x, y, z, w])

_check_inference(
bb,
relax.op.split(x, 1),
R.Tuple(
R.Tensor([16, 4]),
),
)
_check_inference(
bb,
relax.op.split(x, 2),
R.Tuple(
R.Tensor([8, 4]),
R.Tensor([8, 4]),
),
)
# Uneven splits are allowed, with the last split being smaller than the others.
_check_inference(
bb,
relax.op.split(x, 3),
R.Tuple(
R.Tensor([6, 4]),
R.Tensor([6, 4]),
R.Tensor([4, 4]),
),
)

# Dtype of result is inherited from the tensor
_check_inference(
bb,
relax.op.split(y, 2),
R.Tuple(
R.Tensor([8, 4], "float32"),
R.Tensor([8, 4], "float32"),
),
)

# Axis can be explicitly specified. Otherwise, defaults to axis=0.
_check_inference(
bb, relax.op.split(x, [2], axis=1), R.Tuple(R.Tensor([16, 2]), R.Tensor([16, 2]))
)

# Split points can be explicitly specified
_check_inference(
bb,
relax.op.split(x, [2]),
R.Tuple(
R.Tensor([2, 4]),
R.Tensor([14, 4]),
),
)
_check_inference(
bb,
relax.op.split(x, [2, 5]),
R.Tuple(
R.Tensor([2, 4]),
R.Tensor([3, 4]),
R.Tensor([11, 4]),
),
)

# Splitting a dynamic axis is allowed, and propagates the shape to the output
_check_inference(
bb,
relax.op.split(z, 2),
R.Tuple(
R.Tensor([(n + 1) // 2, 16]),
R.Tensor([n - (n + 1) // 2, 16]),
),
)
_check_inference(
bb,
relax.op.split(z, 3),
R.Tuple(
R.Tensor([(n + 2) // 3, 16]),
R.Tensor([(n + 2) // 3, 16]),
R.Tensor([n - (n + 2) // 3 * 2, 16]),
),
)

# Splitting a dynamic axis at specific indices is allowed.
_check_inference(
bb,
relax.op.split(w, [2, 5]),
R.Tuple(
R.Tensor((2, 16)),
R.Tensor((3, 16)),
R.Tensor((n, 16)),
),
)


def test_split_infer_struct_info_non_integer_indices():
bb = relax.BlockBuilder()
a = tir.Var("c", "int64")
47 changes: 47 additions & 0 deletions tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,5 +525,52 @@ def expected(
tvm.ir.assert_structural_equal(after, expected)


def test_dynamic_rhs():
@R.function(private=True)
def before(
x: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, "M"), "float32"),
):
M = T.int64()
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
out = (lv0, lv1)
R.output(out)
return out

@R.function(private=True)
def expected(
x: R.Tensor((2, 1024, 640), dtype="float32"),
w0: R.Tensor((640, 640), dtype="float32"),
w1: R.Tensor((640, "M"), dtype="float32"),
) -> R.Tuple(
R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, "M"), dtype="float32")
):
M = T.int64()
with R.dataflow():
lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w0, w1), axis=1)
lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
x, lv, out_dtype="float32"
)
lv2: R.Tuple(
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, M), dtype="float32"),
) = R.split(lv1, indices_or_sections=[640], axis=2)
lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
lv1_1: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
out: R.Tuple(
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, M), dtype="float32"),
) = (lv0, lv1_1)
R.output(out)
return out

after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]

tvm.ir.assert_structural_equal(after, expected)


if __name__ == "__main__":
tvm.testing.main()