Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,24 +334,37 @@ inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index,
return offsets;
}

Buffer Buffer::GetFlattenedBuffer() const {
auto self = operator->();

/*!
 * \brief Validate the axis separators supplied for a buffer.
 *
 * The separators must be sorted in non-decreasing order, and every
 * separator must lie in the closed range [0, buffer_dim].  Repeated
 * separators, and separators equal to 0 or to buffer_dim, are accepted.
 *
 * A violation is reported through a failing CHECK_* whose message is
 * prefixed with "ValueError: ".
 *
 * \param axis_separators The axis separators to validate.
 * \param buffer_dim The number of dimensions of the buffer's shape.
 */
static void ValidateAxisSeparators(const Array<IntImm>& axis_separators, size_t buffer_dim) {
  // Adjacent separators may be equal, so only a strictly decreasing
  // pair is rejected here.
  for (size_t i = 0; (i + 1) < axis_separators.size(); i++) {
    auto sep = axis_separators[i]->value;
    auto next_sep = axis_separators[i + 1]->value;
    CHECK_LE(sep, next_sep) << "ValueError: "
                            << "Axis separators must be in increasing order, "
                            << "but axis_separators[" << i << "] = " << sep
                            << " is greater than axis_separators[" << (i + 1)
                            << "] = " << next_sep << ".";
  }
  if (axis_separators.size()) {
    // The ordering check above means that bounding the first separator
    // from below and the last separator from above bounds all of them.
    auto first_sep = axis_separators[0]->value;
    CHECK_GE(first_sep, 0) << "ValueError: "
                           << "All axis separators must be non-negative.  "
                           << "However, the axis_separators[0] = " << first_sep;
    auto last_sep = axis_separators[axis_separators.size() - 1]->value;
    CHECK_LE(last_sep, buffer_dim)
        << "ValueError: "
        << "All axis separators must be within the range "
        << "0 <= sep <= buffer_dim.  "
        << "However, the last axis_separators[" << (axis_separators.size() - 1)
        << "] = " << last_sep << " is greater than the buffer's dimensionality of " << buffer_dim;
  }
}

Buffer Buffer::GetFlattenedBuffer() const {
auto self = operator->();

ValidateAxisSeparators(self->axis_separators, self->shape.size());

Array<PrimExpr> output_shape;
if (self->strides.size()) {
Expand Down Expand Up @@ -565,6 +578,8 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
<< "Variable " << data->name_hint << " does not point to a primitive.";

ValidateAxisSeparators(axis_separators, shape.size());

auto n = make_object<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
Expand Down
15 changes: 10 additions & 5 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1485,11 +1485,16 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
if (it != buffer_var_map_.end()) {
const Buffer& new_source_buffer = it->second;
Buffer new_target_buffer = match_buffer->buffer;
new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
LOG(WARNING)
<< "Target buffer in match_buffer doesn't have the same dimensionality as its source "
"buffer. `axis_separators` for the target buffer might be incorrect.";

if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) {
new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
} else {
new_target_buffer.CopyOnWrite()->axis_separators =
Array<IntImm>(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0));
LOG(WARNING) << "Buffer view " << new_target_buffer
<< " has different dimensionality than backing buffer " << new_source_buffer
<< ". The `axis_separators` for " << new_target_buffer << "."
<< "`axis_separators` for the view might be incorrect.";
}
buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
return MatchBufferRegion(new_target_buffer,
Expand Down
12 changes: 9 additions & 3 deletions tests/python/tir-base/test_tir_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod():
A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))

def assert_simplified_equal(index_simplified, index_direct):
    """Check that two index expressions are structurally equal.

    Parameters
    ----------
    index_simplified
        The index expression produced by the simplified code path.
    index_direct
        The index expression written out directly, used as the reference.

    Notes
    -----
    The comparison is delegated to ``tvm.ir.assert_structural_equal``.
    A custom message placed next to that call in a tuple expression is
    built and immediately discarded, so no such message is kept here.
    """
    tvm.ir.assert_structural_equal(index_simplified, index_direct)

idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod
Expand Down Expand Up @@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators():
tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32])


def test_invalid_axis_separators_raises_exception():
    """Declaring a buffer with out-of-range axis separators must fail.

    The buffer is one-dimensional while the last separator is 2, so the
    declaration is expected to raise ``ValueError``.
    """
    shape = [1]
    out_of_range_separators = [1, 2]
    with pytest.raises(ValueError):
        tvm.tir.decl_buffer(shape, axis_separators=out_of_range_separators)


if __name__ == "__main__":
tvm.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "flo
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
B_subregion0[()] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
C[vi, vj] = B_subregion1[()] + T.float32(1)


Expand Down