diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 025605333138..b7c4eb1d42ec 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -334,24 +334,37 @@ inline Array BufferOffset(const BufferNode* n, Array index, return offsets; } -Buffer Buffer::GetFlattenedBuffer() const { - auto self = operator->(); - +static void ValidateAxisSeparators(const Array& axis_separators, size_t buffer_dim) { // These checks ensure that all output axes contain at least one // input axis. - for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) { - auto sep = self->axis_separators[i]->value; - auto next_sep = self->axis_separators[i + 1]->value; - ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order."; - } - if (self->axis_separators.size()) { - auto first_sep = self->axis_separators[0]->value; - ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, " - << "so that first output axis contains at least one input axis"; - auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value; - ICHECK_LT(last_sep, self->shape.size()) - << "Last output axis must contain at least one input axis."; + for (size_t i = 0; (i + 1) < axis_separators.size(); i++) { + auto sep = axis_separators[i]->value; + auto next_sep = axis_separators[i + 1]->value; + CHECK_LE(sep, next_sep) << "ValueError: " + << "Axis separators must be in increasing order, " + << "but axis_separators[" << i << "] = " << sep + << " is greater than or equal to axis_separators[" << (i + 1) + << "] = " << next_sep << "."; + } + if (axis_separators.size()) { + auto first_sep = axis_separators[0]->value; + CHECK_GE(first_sep, 0) << "ValueError: " + << "All axis separators must be non-negative. " + << "However, the axis_separators[0] = " << first_sep; + auto last_sep = axis_separators[axis_separators.size() - 1]->value; + CHECK_LE(last_sep, buffer_dim) + << "ValueError: " + << "All axis separators must be within the range " + << "0 <= sep <= buffer_dim. " + << "However, the last axis_separators[" << (axis_separators.size() - 1) + << "] = " << last_sep << " is greater than the buffer's dimensionality of " << buffer_dim; } +} + +Buffer Buffer::GetFlattenedBuffer() const { + auto self = operator->(); + + ValidateAxisSeparators(self->axis_separators, self->shape.size()); Array output_shape; if (self->strides.size()) { @@ -565,6 +578,8 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array ICHECK(data->type_annotation.as()->element_type.as()) << "Variable " << data->name_hint << " does not point to a primitive."; + ValidateAxisSeparators(axis_separators, shape.size()); + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index f1e9106a635b..8b95e0dc622f 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1485,11 +1485,16 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { if (it != buffer_var_map_.end()) { const Buffer& new_source_buffer = it->second; Buffer new_target_buffer = match_buffer->buffer; - new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; - if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) { - LOG(WARNING) - << "Target buffer in match_buffer doesn't have the same dimensionality as its source " - "buffer. `axis_separators` for the target buffer might be incorrect."; + + if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) { + new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; + } else { + new_target_buffer.CopyOnWrite()->axis_separators = + Array(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0)); + LOG(WARNING) << "Buffer view " << new_target_buffer + << " has different dimensionality than backing buffer " << new_source_buffer + << ". The `axis_separators` for " << new_target_buffer << "." + << "`axis_separators` for the view might be incorrect."; } buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer; return MatchBufferRegion(new_target_buffer, diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index 1ab7662b0b6b..b4b773197b14 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod(): A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1)) def assert_simplified_equal(index_simplified, index_direct): - tvm.ir.assert_structural_equal( - index_simplified, index_direct - ), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct) + ( + tvm.ir.assert_structural_equal(index_simplified, index_direct), + "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct), + ) idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod @@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators(): tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) +def test_invalid_axis_separators_raises_exception(): + with pytest.raises(ValueError): + tvm.tir.decl_buffer([1], axis_separators=[1, 2]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py index 76a6ade42f50..788e17e77146 100644 --- a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py +++ b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py @@ -94,12 +94,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "flo for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) + B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) B_subregion0[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) + B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) C[vi, vj] = B_subregion1[()] + T.float32(1)