Skip to content

Commit b15daba

Browse files
committed
[Unity] Infer struct info for relax.op.split on dynamic-sized index
1 parent e4b1d68 commit b15daba

File tree

4 files changed

+189
-28
lines changed

4 files changed

+189
-28
lines changed

src/relax/ir/block_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,13 @@ class BlockBuilderImpl : public BlockBuilderNode {
170170
auto it = shape_var_map.find(shape_var);
171171
if (it == shape_var_map.end()) {
172172
shape_var_map.Set(shape_var, shape_expr);
173+
analyzer_.MarkGlobalNonNegValue(shape_var);
173174
} else {
174175
const PrimExpr& old_shape_expr = (*it).second;
175176
CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr))
176177
<< "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs "
177178
<< shape_expr;
178179
}
179-
shape_var_map.Set(shape_var, shape_expr);
180180
}
181181
}
182182
scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)}));

src/relax/op/tensor/manipulate.cc

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -846,41 +846,48 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
846846
int axis =
847847
data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
848848

849-
if (const auto* p_indices = attrs->indices_or_sections.as<ArrayNode>()) {
849+
if (auto opt_indices = attrs->indices_or_sections.as<Array<IntImm>>()) {
850+
auto p_indices = opt_indices.value();
850851
// When there is no index, return the input tensor's struct info.
851-
if (p_indices->size() == 0) {
852+
if (p_indices.size() == 0) {
852853
return TupleStructInfo({data_sinfo});
853854
}
854855
// Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
855856
if (data_shape == nullptr) {
856857
return TupleStructInfo(Array<StructInfo>(
857-
p_indices->size() + 1,
858+
p_indices.size() + 1,
858859
TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
859860
}
860861

861862
ICHECK_NE(axis, -1);
862-
const auto* axis_length = data_shape->values[axis].as<IntImmNode>();
863-
// Fall back to unknown shape when the input tensor shape at the given axis is symbolic.
864-
if (axis_length == nullptr) {
865-
return TupleStructInfo(Array<StructInfo>(
866-
p_indices->size() + 1,
867-
TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
868-
}
869863

870-
// Only do output shape inference when all the indices and the total length are integers.
871-
Array<IntImm> indices = GetRef<Array<IntImm>>(p_indices);
872864
IntImm zero(DataType::Int(64), /*value=*/0);
873-
indices.insert(indices.begin(), zero);
874-
indices.insert(indices.end(), Downcast<IntImm>(data_shape->values[axis]));
875865

876866
std::vector<StructInfo> output_sinfo;
877-
output_sinfo.reserve(indices.size() - 1);
878-
for (int i = 0; i + 1 < static_cast<int>(indices.size()); ++i) {
879-
PrimExpr l = tvm::max(zero, indices[i]);
880-
PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]);
867+
for (size_t i = 0; i < p_indices.size() + 1; i++) {
868+
PrimExpr left;
869+
if (i == 0) {
870+
left = zero;
871+
} else {
872+
left = p_indices[i - 1];
873+
}
874+
875+
PrimExpr right;
876+
if (i < p_indices.size()) {
877+
right = p_indices[i];
878+
} else {
879+
right = data_shape->values[axis];
880+
}
881+
882+
left = tvm::min(tvm::max(left, 0), data_shape->values[axis]);
883+
right = tvm::min(tvm::max(right, 0), data_shape->values[axis]);
884+
885+
PrimExpr split_dim = right - left;
886+
split_dim = tvm::max(split_dim, 0);
887+
split_dim = ctx->GetAnalyzer()->Simplify(split_dim);
881888

882889
Array<PrimExpr> shape = data_shape->values;
883-
shape.Set(axis, tvm::max(zero, r - l));
890+
shape.Set(axis, split_dim);
884891
output_sinfo.push_back(
885892
TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
886893
}
@@ -899,6 +906,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
899906
}
900907
ICHECK_NE(axis, -1);
901908
PrimExpr split_len = ceildiv(data_shape->values[axis], n_section);
909+
split_len = ctx->GetAnalyzer()->Simplify(split_len);
902910

903911
// Construct struct info for tensors except the last one.
904912
Array<PrimExpr> shape = data_shape->values;
@@ -907,7 +915,9 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
907915
n_section - 1, TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
908916

909917
// Construct struct info for the last tensor.
910-
shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1));
918+
PrimExpr last_split_len = data_shape->values[axis] - split_len * (n_section - 1);
919+
last_split_len = ctx->GetAnalyzer()->Simplify(last_split_len);
920+
shape.Set(axis, last_split_len);
911921
output_sinfo.push_back(
912922
TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
913923
return TupleStructInfo(output_sinfo);

tests/python/relax/test_op_manipulate.py

Lines changed: 111 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tvm import relax, tir
2121
from tvm import TVMError
2222
from tvm.ir import Op, VDevice
23-
from tvm.script import relax as R
23+
from tvm.script import relax as R, tir as T
2424

2525

2626
def test_op_correctness():
@@ -1832,9 +1832,9 @@ def test_split_infer_struct_info_by_indices_shape_symbolic():
18321832
relax.op.split(x, [10, 20], axis=1),
18331833
relax.TupleStructInfo(
18341834
[
1835-
relax.TensorStructInfo(dtype="float32", ndim=2),
1836-
relax.TensorStructInfo(dtype="float32", ndim=2),
1837-
relax.TensorStructInfo(dtype="float32", ndim=2),
1835+
relax.TensorStructInfo([a, T.max(T.min(10, b) - T.min(0, b), 0)], dtype="float32"),
1836+
relax.TensorStructInfo([a, T.max(T.min(20, b) - T.min(10, b), 0)], dtype="float32"),
1837+
relax.TensorStructInfo([a, T.max(b - 20, 0)], dtype="float32"),
18381838
]
18391839
),
18401840
)
@@ -1987,9 +1987,9 @@ def test_split_infer_struct_info_by_n_section_shape_symbolic():
19871987
relax.op.split(x, 3, axis=1),
19881988
relax.TupleStructInfo(
19891989
[
1990-
relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
1991-
relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
1992-
relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2), "float32"),
1990+
relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
1991+
relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
1992+
relax.TensorStructInfo((a, b - (b + 2) // 3 * 2), "float32"),
19931993
]
19941994
),
19951995
)
@@ -2176,6 +2176,110 @@ def test_split_indices_or_sections_int64():
21762176
assert split1.attrs.indices_or_sections.dtype == "int64"
21772177

21782178

2179+
def test_split_infer_struct_info():
2180+
bb = relax.BlockBuilder()
2181+
n = tir.Var("n", "int64")
2182+
x = relax.Var("x", R.Tensor((16, 4)))
2183+
y = relax.Var("y", R.Tensor((16, 4), "float32"))
2184+
z = relax.Var("z", R.Tensor((n, 16)))
2185+
w = relax.Var("w", R.Tensor((n + 5, 16)))
2186+
2187+
_check_inference(
2188+
bb,
2189+
relax.op.split(x, 1),
2190+
R.Tuple(
2191+
R.Tensor([16, 4]),
2192+
),
2193+
)
2194+
_check_inference(
2195+
bb,
2196+
relax.op.split(x, 2),
2197+
R.Tuple(
2198+
R.Tensor([8, 4]),
2199+
R.Tensor([8, 4]),
2200+
),
2201+
)
2202+
# Uneven splits are allowed, with the last split being smaller than the others.
2203+
_check_inference(
2204+
bb,
2205+
relax.op.split(x, 3),
2206+
R.Tuple(
2207+
R.Tensor([6, 4]),
2208+
R.Tensor([6, 4]),
2209+
R.Tensor([4, 4]),
2210+
),
2211+
)
2212+
2213+
# Dtype of result is inherited from the tensor
2214+
_check_inference(
2215+
bb,
2216+
relax.op.split(y, 2),
2217+
R.Tuple(
2218+
R.Tensor([8, 4], "float32"),
2219+
R.Tensor([8, 4], "float32"),
2220+
),
2221+
)
2222+
2223+
# Axis can be explicitly specified. Otherwise, defaults to axis=0.
2224+
_check_inference(
2225+
bb, relax.op.split(x, [2], axis=1), R.Tuple(R.Tensor([16, 2]), R.Tensor([16, 2]))
2226+
)
2227+
2228+
# Split points can be explicitly specified
2229+
_check_inference(
2230+
bb,
2231+
relax.op.split(x, [2]),
2232+
R.Tuple(
2233+
R.Tensor([2, 4]),
2234+
R.Tensor([14, 4]),
2235+
),
2236+
)
2237+
_check_inference(
2238+
bb,
2239+
relax.op.split(x, [2, 5]),
2240+
R.Tuple(
2241+
R.Tensor([2, 4]),
2242+
R.Tensor([3, 4]),
2243+
R.Tensor([11, 4]),
2244+
),
2245+
)
2246+
2247+
# Splitting a dynamic axis is allowed, and propagates the shape to the output
2248+
_check_inference(
2249+
bb,
2250+
relax.op.split(z, 2),
2251+
R.Tuple(
2252+
R.Tensor([(n + 1) // 2, 16]),
2253+
R.Tensor([n - (n + 1) // 2, 16]),
2254+
),
2255+
)
2256+
_check_inference(
2257+
bb,
2258+
relax.op.split(z, 3),
2259+
R.Tuple(
2260+
R.Tensor([(n + 2) // 3, 16]),
2261+
R.Tensor([(n + 2) // 3, 16]),
2262+
R.Tensor([n - (n + 2) // 3 * 2, 16]),
2263+
),
2264+
)
2265+
2266+
# Splitting a dynamic axis at specific indices is allowed. The
2267+
# algebraic form here isn't the cleanest, primarily because the
2268+
# test case doesn't know that `n` is a shape variable. When
2269+
# occurring in a relax function, `n` would be marked with
2270+
# `analyzer_.MarkGlobalNonNegValue`, which would make the shapes
2271+
# simplify to `[(2,16), (3,16), (n,16)]`.
2272+
_check_inference(
2273+
bb,
2274+
relax.op.split(w, [2, 5]),
2275+
R.Tuple(
2276+
R.Tensor((T.max(T.min(2, n + 5) - T.min(0, n + 5), 0), 16)),
2277+
R.Tensor((T.max(T.min(5, n + 5) - T.min(2, n + 5), 0), 16)),
2278+
R.Tensor((T.max(n, 0), 16)),
2279+
),
2280+
)
2281+
2282+
21792283
def test_split_infer_struct_info_non_integer_indices():
21802284
bb = relax.BlockBuilder()
21812285
a = tir.Var("c", "int64")

tests/python/relax/test_transform_combine_parallel_matmul.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,5 +525,52 @@ def expected(
525525
tvm.ir.assert_structural_equal(after, expected)
526526

527527

528+
def test_dynamic_rhs():
529+
@R.function(private=True)
530+
def before(
531+
x: R.Tensor((2, 1024, 640), "float32"),
532+
w0: R.Tensor((640, 640), "float32"),
533+
w1: R.Tensor((640, "M"), "float32"),
534+
):
535+
M = T.int64()
536+
with R.dataflow():
537+
lv0 = R.matmul(x, w0)
538+
lv1 = R.matmul(x, w1)
539+
out = (lv0, lv1)
540+
R.output(out)
541+
return out
542+
543+
@R.function(private=True)
544+
def expected(
545+
x: R.Tensor((2, 1024, 640), dtype="float32"),
546+
w0: R.Tensor((640, 640), dtype="float32"),
547+
w1: R.Tensor((640, "M"), dtype="float32"),
548+
) -> R.Tuple(
549+
R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, "M"), dtype="float32")
550+
):
551+
M = T.int64()
552+
with R.dataflow():
553+
lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w0, w1), axis=1)
554+
lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
555+
x, lv, out_dtype="float32"
556+
)
557+
lv2: R.Tuple(
558+
R.Tensor((2, 1024, 640), dtype="float32"),
559+
R.Tensor((2, 1024, M), dtype="float32"),
560+
) = R.split(lv1, indices_or_sections=[640], axis=2)
561+
lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
562+
lv1_1: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
563+
out: R.Tuple(
564+
R.Tensor((2, 1024, 640), dtype="float32"),
565+
R.Tensor((2, 1024, M), dtype="float32"),
566+
) = (lv0, lv1_1)
567+
R.output(out)
568+
return out
569+
570+
after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
571+
572+
tvm.ir.assert_structural_equal(after, expected)
573+
574+
528575
if __name__ == "__main__":
529576
tvm.testing.main()

0 commit comments

Comments
 (0)