From ed7c05de8e1062c0b066055e639d3bf6456516db Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 12 Dec 2023 22:20:29 +0000
Subject: [PATCH 1/2] [Unity] Infer struct info for relax.op.split on
 dynamic-sized index

---
 src/relax/ir/block_builder.cc                 |   2 +-
 src/relax/op/tensor/manipulate.cc             |  50 +++++---
 tests/python/relax/test_op_manipulate.py      | 118 ++++++++++++++++--
 .../test_transform_combine_parallel_matmul.py |  47 +++++++
 4 files changed, 189 insertions(+), 28 deletions(-)

diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index f74434bd7453..75f7459015aa 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -170,13 +170,13 @@ class BlockBuilderImpl : public BlockBuilderNode {
       auto it = shape_var_map.find(shape_var);
       if (it == shape_var_map.end()) {
         shape_var_map.Set(shape_var, shape_expr);
+        analyzer_.MarkGlobalNonNegValue(shape_var);
       } else {
         const PrimExpr& old_shape_expr = (*it).second;
         CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr))
             << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs "
             << shape_expr;
       }
-      shape_var_map.Set(shape_var, shape_expr);
     }
   }
   scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)}));
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 12342aecf284..ad2a812c8254 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -846,41 +846,48 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
   int axis =
       data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
 
-  if (const auto* p_indices = attrs->indices_or_sections.as<ArrayNode>()) {
+  if (auto opt_indices = attrs->indices_or_sections.as<Array<IntImm>>()) {
+    auto p_indices = opt_indices.value();
     // When there is not index, return the input tensor's struct info.
-    if (p_indices->size() == 0) {
+    if (p_indices.size() == 0) {
       return TupleStructInfo({data_sinfo});
     }
     // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
     if (data_shape == nullptr) {
       return TupleStructInfo(Array<StructInfo>(
-          p_indices->size() + 1,
+          p_indices.size() + 1,
           TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
     }
     ICHECK_NE(axis, -1);
-    const auto* axis_length = data_shape->values[axis].as<IntImmNode>();
-    // Fall back to unknown shape when the input tensor shape at the given axis is symbolic.
-    if (axis_length == nullptr) {
-      return TupleStructInfo(Array<StructInfo>(
-          p_indices->size() + 1,
-          TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
-    }
-    // Only do output shape inference when all the indices and the total length are integers.
-    Array<IntImm> indices = GetRef<Array<IntImm>>(p_indices);
     IntImm zero(DataType::Int(64), /*value=*/0);
-    indices.insert(indices.begin(), zero);
-    indices.insert(indices.end(), Downcast<IntImm>(data_shape->values[axis]));
 
     std::vector<StructInfo> output_sinfo;
-    output_sinfo.reserve(indices.size() - 1);
-    for (int i = 0; i + 1 < static_cast<int>(indices.size()); ++i) {
-      PrimExpr l = tvm::max(zero, indices[i]);
-      PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]);
+    for (size_t i = 0; i < p_indices.size() + 1; i++) {
+      PrimExpr left;
+      if (i == 0) {
+        left = zero;
+      } else {
+        left = p_indices[i - 1];
+      }
+
+      PrimExpr right;
+      if (i < p_indices.size()) {
+        right = p_indices[i];
+      } else {
+        right = data_shape->values[axis];
+      }
+
+      left = tvm::min(tvm::max(left, 0), data_shape->values[axis]);
+      right = tvm::min(tvm::max(right, 0), data_shape->values[axis]);
+
+      PrimExpr split_dim = right - left;
+      split_dim = tvm::max(split_dim, 0);
+      split_dim = ctx->GetAnalyzer()->Simplify(split_dim);
 
       Array<PrimExpr> shape = data_shape->values;
-      shape.Set(axis, tvm::max(zero, r - l));
+      shape.Set(axis, split_dim);
       output_sinfo.push_back(
           TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
     }
@@ -899,6 +906,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
     }
     ICHECK_NE(axis, -1);
     PrimExpr split_len = ceildiv(data_shape->values[axis], n_section);
+    split_len = ctx->GetAnalyzer()->Simplify(split_len);
 
     // Construct struct info for tensors except the last one.
     Array<PrimExpr> shape = data_shape->values;
@@ -907,7 +915,9 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
         n_section - 1,
         TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
     // Construct struct info for the last tensor.
-    shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1));
+    PrimExpr last_split_len = data_shape->values[axis] - split_len * (n_section - 1);
+    last_split_len = ctx->GetAnalyzer()->Simplify(last_split_len);
+    shape.Set(axis, last_split_len);
     output_sinfo.push_back(
         TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
     return TupleStructInfo(output_sinfo);
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 6c0fbcf22770..f672234a3afa 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -20,7 +20,7 @@
 from tvm import relax, tir
 from tvm import TVMError
 from tvm.ir import Op, VDevice
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
 
 
 def test_op_correctness():
@@ -1832,9 +1832,9 @@ def test_split_infer_struct_info_by_indices_shape_symbolic():
         relax.op.split(x, [10, 20], axis=1),
         relax.TupleStructInfo(
             [
-                relax.TensorStructInfo(dtype="float32", ndim=2),
-                relax.TensorStructInfo(dtype="float32", ndim=2),
-                relax.TensorStructInfo(dtype="float32", ndim=2),
+                relax.TensorStructInfo([a, T.max(T.min(10, b) - T.min(0, b), 0)], dtype="float32"),
+                relax.TensorStructInfo([a, T.max(T.min(20, b) - T.min(10, b), 0)], dtype="float32"),
+                relax.TensorStructInfo([a, T.max(b - 20, 0)], dtype="float32"),
             ]
         ),
     )
@@ -1987,9 +1987,9 @@ def test_split_infer_struct_info_by_n_section_shape_symbolic():
         relax.op.split(x, 3, axis=1),
         relax.TupleStructInfo(
             [
-                relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
-                relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
-                relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2), "float32"),
+                relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
+                relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
+                relax.TensorStructInfo((a, b - (b + 2) // 3 * 2), "float32"),
             ]
         ),
     )
@@ -2176,6 +2176,110 @@ def test_split_indices_or_sections_int64():
     assert split1.attrs.indices_or_sections.dtype == "int64"
 
 
+def test_split_infer_struct_info():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((16, 4)))
+    y = relax.Var("y", R.Tensor((16, 4), "float32"))
+    z = relax.Var("z", R.Tensor((n, 16)))
+    w = relax.Var("w", R.Tensor((n + 5, 16)))
+
+    _check_inference(
+        bb,
+        relax.op.split(x, 1),
+        R.Tuple(
+            R.Tensor([16, 4]),
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.split(x, 2),
+        R.Tuple(
+            R.Tensor([8, 4]),
+            R.Tensor([8, 4]),
+        ),
+    )
+    # Uneven splits are allowed, with the last split being smaller than the others.
+    _check_inference(
+        bb,
+        relax.op.split(x, 3),
+        R.Tuple(
+            R.Tensor([6, 4]),
+            R.Tensor([6, 4]),
+            R.Tensor([4, 4]),
+        ),
+    )
+
+    # The dtype of each result is inherited from the input tensor.
+    _check_inference(
+        bb,
+        relax.op.split(y, 2),
+        R.Tuple(
+            R.Tensor([8, 4], "float32"),
+            R.Tensor([8, 4], "float32"),
+        ),
+    )
+
+    # The axis can be explicitly specified.  Otherwise, it defaults to axis=0.
+    _check_inference(
+        bb, relax.op.split(x, [2], axis=1), R.Tuple(R.Tensor([16, 2]), R.Tensor([16, 2]))
+    )
+
+    # Split points can be explicitly specified.
+    _check_inference(
+        bb,
+        relax.op.split(x, [2]),
+        R.Tuple(
+            R.Tensor([2, 4]),
+            R.Tensor([14, 4]),
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.split(x, [2, 5]),
+        R.Tuple(
+            R.Tensor([2, 4]),
+            R.Tensor([3, 4]),
+            R.Tensor([11, 4]),
+        ),
+    )
+
+    # Splitting a dynamic axis is allowed, and propagates the shape to the output.
+    _check_inference(
+        bb,
+        relax.op.split(z, 2),
+        R.Tuple(
+            R.Tensor([(n + 1) // 2, 16]),
+            R.Tensor([n - (n + 1) // 2, 16]),
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.split(z, 3),
+        R.Tuple(
+            R.Tensor([(n + 2) // 3, 16]),
+            R.Tensor([(n + 2) // 3, 16]),
+            R.Tensor([n - (n + 2) // 3 * 2, 16]),
+        ),
+    )
+
+    # Splitting a dynamic axis at specific indices is allowed. The
+    # algebraic form here isn't the cleanest, primarily because the
+    # test case doesn't know that `n` is a shape variable. When
+    # occurring in a relax function, `n` would be marked with
+    # `analyzer_.MarkGlobalNonNegValue`, which would make the shapes
+    # simplify to `[(2,16), (3,16), (n,16)]`.
+    _check_inference(
+        bb,
+        relax.op.split(w, [2, 5]),
+        R.Tuple(
+            R.Tensor((T.max(T.min(2, n + 5) - T.min(0, n + 5), 0), 16)),
+            R.Tensor((T.max(T.min(5, n + 5) - T.min(2, n + 5), 0), 16)),
+            R.Tensor((T.max(n, 0), 16)),
+        ),
+    )
+
+
 def test_split_infer_struct_info_non_integer_indices():
     bb = relax.BlockBuilder()
     a = tir.Var("c", "int64")
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py
index b06eddd2bb67..7e7f2328f3b3 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,5 +525,52 @@ def expected(
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_dynamic_rhs():
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, "M"), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            out = (lv0, lv1)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, 640), dtype="float32"),
+        w1: R.Tensor((640, "M"), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, "M"), dtype="float32")
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w0, w1), axis=1)
+            lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
+                x, lv, out_dtype="float32"
+            )
+            lv2: R.Tuple(
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, M), dtype="float32"),
+            ) = R.split(lv1, indices_or_sections=[640], axis=2)
+            lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
+            lv1_1: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
+            out: R.Tuple(
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, M), dtype="float32"),
+            ) = (lv0, lv1_1)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 4decca578309978851457ccf47465354099c62b6 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 5 Feb 2024 14:49:11 +0000
Subject: [PATCH 2/2] Update based on review comments

---
 src/relax/ir/block_builder.cc            |  4 ++++
 tests/python/relax/test_op_manipulate.py | 25 +++++++++++++++---------
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 75f7459015aa..b39beae7403a 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -170,6 +170,10 @@ class BlockBuilderImpl : public BlockBuilderNode {
       auto it = shape_var_map.find(shape_var);
       if (it == shape_var_map.end()) {
         shape_var_map.Set(shape_var, shape_expr);
+        // Expose the shape variable as non-negative, for purposes
+        // of shape inference. In many cases, knowing that the
+        // shape variable is non-negative allows for simpler
+        // expressions for dynamic shapes.
         analyzer_.MarkGlobalNonNegValue(shape_var);
       } else {
         const PrimExpr& old_shape_expr = (*it).second;
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index f672234a3afa..ddb92725d438 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -2184,6 +2184,18 @@ def test_split_infer_struct_info():
     z = relax.Var("z", R.Tensor((n, 16)))
     w = relax.Var("w", R.Tensor((n + 5, 16)))
 
+    # All relax shape variables are non-negative. When a scope
+    # begins, any TIR variables that are used as shape variables are
+    # declared to be non-negative in the `tvm.arith.Analyzer`.
+    # Because `relax.op.split` clamps the indices to be within the
+    # bounds of the axis being split, simplifying with non-negative
+    # shape variables can result in much simpler shapes.
+    #
+    # For example, an axis of size `n + 5`, split on the range from 2
+    # to 5, has size `T.max(T.min(5, n + 5) - T.min(2, n + 5), 0)`.
+    # If it is known that `n >= 0`, this simplifies down to `3`.
+    bb.begin_scope([x, y, z, w])
+
     _check_inference(
         bb,
         relax.op.split(x, 1),
@@ -2263,19 +2275,14 @@ def test_split_infer_struct_info():
         ),
     )
 
-    # Splitting a dynamic axis at specific indices is allowed. The
-    # algebraic form here isn't the cleanest, primarily because the
-    # test case doesn't know that `n` is a shape variable. When
-    # occurring in a relax function, `n` would be marked with
-    # `analyzer_.MarkGlobalNonNegValue`, which would make the shapes
-    # simplify to `[(2,16), (3,16), (n,16)]`.
+    # Splitting a dynamic axis at specific indices is allowed.
     _check_inference(
         bb,
         relax.op.split(w, [2, 5]),
         R.Tuple(
-            R.Tensor((T.max(T.min(2, n + 5) - T.min(0, n + 5), 0), 16)),
-            R.Tensor((T.max(T.min(5, n + 5) - T.min(2, n + 5), 0), 16)),
-            R.Tensor((T.max(n, 0), 16)),
+            R.Tensor((2, 16)),
+            R.Tensor((3, 16)),
+            R.Tensor((n, 16)),
         ),
     )
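
The clamping rule that the first patch introduces in `InferStructInfoSplit` can be
sketched in plain Python. This is a minimal illustration, not part of the patch:
the helper name `split_extents` is hypothetical, and the real implementation
operates on symbolic `PrimExpr` values, simplifying each extent with the block
builder's analyzer.

    def split_extents(axis_len, indices):
        # Mirror of the by-indices branch of InferStructInfoSplit:
        # every bound is clamped to [0, axis_len], and every extent
        # is clamped to be non-negative.
        bounds = [0, *indices, axis_len]
        extents = []
        for left, right in zip(bounds, bounds[1:]):
            left = min(max(left, 0), axis_len)
            right = min(max(right, 0), axis_len)
            extents.append(max(right - left, 0))
        return extents

    assert split_extents(16, [2, 5]) == [2, 3, 11]  # in-range indices
    assert split_extents(16, [20]) == [16, 0]       # out-of-range index yields an empty split
    assert split_extents(16, [5, 2]) == [5, 0, 14]  # out-of-order indices yield an empty middle split

The clamping matches how `numpy.split` slices with out-of-range or out-of-order
indices, so the inferred static extents agree with the values produced at runtime.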
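
The effect of `MarkGlobalNonNegValue` in the second patch can be reproduced from
Python with `tvm.arith.Analyzer`. A minimal sketch, assuming the standard
`tvm.arith` bindings; since `MarkGlobalNonNegValue` is invoked on the C++ side,
the sketch approximates it by registering a `ConstIntBound` of `[0, +inf]` for
the variable:

    import tvm
    from tvm import tir

    analyzer = tvm.arith.Analyzer()
    n = tir.Var("n", "int64")

    def i64(value):
        # Local helper for int64 constants, introduced only for this sketch.
        return tir.const(value, "int64")

    # Extent of the middle piece of relax.op.split(w, [2, 5]) on an axis
    # of length n + 5, as inferred before simplification.
    extent = tir.max(tir.min(i64(5), n + 5) - tir.min(i64(2), n + 5), i64(0))

    # Without any bound on n, the min/max terms cannot be eliminated.
    print(analyzer.simplify(extent))

    # Declaring n >= 0, approximating what BlockBuilder::BeginScope now
    # does for shape variables, collapses the extent to a constant.
    analyzer.update(n, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF))
    print(analyzer.simplify(extent))  # expected: 3

This is the same simplification that turns the expected shapes in
`test_split_infer_struct_info` from `T.max(T.min(5, n + 5) - T.min(2, n + 5), 0)`
into `3` once `bb.begin_scope([x, y, z, w])` declares the shape variables.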