Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@
namespace tvm {
namespace tir {

class CollectUnmanagedAllocations : public StmtExprVisitor {
class CollectManagedAllocations : public StmtExprVisitor {
public:
void VisitStmt_(const AllocateNode* op) final {
unmanaged_allocations.insert(op->buffer_var.get());
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const AllocateConstNode* op) final {
unmanaged_allocations.insert(op->buffer_var.get());
void VisitStmt_(const BlockNode* op) final {
for (const auto& buf : op->alloc_buffers) {
managed_allocations.insert(buf->data.get());
}
for (const auto& buf : op->match_buffers) {
managed_allocations.insert(buf->buffer->data.get());
}
StmtExprVisitor::VisitStmt_(op);
}

/*! \brief Data variables of buffers listed in a BlockNode's alloc_buffers or match_buffers.
* Only these buffers may be moved by BufferAllocationLocator. */
std::unordered_set<const VarNode*> unmanaged_allocations;
std::unordered_set<const VarNode*> managed_allocations;
};

/*! \brief Collect the allocate buffer order. */
Expand Down Expand Up @@ -108,15 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator {
// since the buffer_lca Map is unordered.
Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);
std::unordered_set<const VarNode*> arg_buffer_vars;
CollectUnmanagedAllocations collector;
CollectManagedAllocations collector;
collector(func->body);
unmanaged_allocations_ = collector.unmanaged_allocations;

for (const Var& param : func->params) {
if (param->type_annotation.defined() && param->type_annotation.as<PointerTypeNode>()) {
unmanaged_allocations_.insert(param.get());
}
}
managed_allocations_ = collector.managed_allocations;

for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
Expand All @@ -131,7 +125,7 @@ class BufferAllocationLocator : public StmtExprMutator {
if (arg_buffer_vars.count(buffer->data.get())) {
continue;
}
if (!unmanaged_allocations_.count(buffer->data.get())) {
if (managed_allocations_.count(buffer->data.get())) {
alloc_buffers_[stmt].push_back(buffer);
}
buffer_data_to_buffer_.Set(buffer->data, buffer);
Expand All @@ -152,7 +146,7 @@ class BufferAllocationLocator : public StmtExprMutator {

Array<Buffer> new_block_alloc_bufs;
for (const Buffer& buf : it->second) {
if (!unmanaged_allocations_.count(buf->data.get())) {
if (managed_allocations_.count(buf->data.get())) {
buffer_data_to_buffer_.erase(buf->data);
new_block_alloc_bufs.push_back(buf);
}
Expand Down Expand Up @@ -243,8 +237,8 @@ class BufferAllocationLocator : public StmtExprMutator {
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
/*! \brief The buffer already allocated during recursive visiting. */
Map<Var, Buffer> buffer_data_to_buffer_;
/*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved. */
std::unordered_set<const VarNode*> unmanaged_allocations_;
/*! \brief Buffers that are allocated within a BlockNode, and may be moved. */
std::unordered_set<const VarNode*> managed_allocations_;
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ def test_allocate_const_after_tensorize():


def test_buffer_conditional_lowering():
"""
"""Buffers passed as pointer arguments are unmodified

Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass
leaves (Buffer nodes corresponding to pointer-typed PrimFunc arguments)
unchanged, rather than lowering them to `reads`, `writes`, and `alloc_buffer` nodes.
Expand All @@ -434,5 +435,35 @@ def before(A: T.handle("float32")):
_check(before, after)


def test_dltensor_buffer_is_unlowered():
    """Buffers declared over pointers read from a handle argument are unmodified

    Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass
    leaves the `T.decl_buffer` nodes in a PrimFunc that unpacks a
    DLTensor-style handle argument with `T.tvm_struct_get` unchanged,
    rather than lowering them to `reads`, `writes`, and
    `alloc_buffer` nodes.

    The expected function is the input function itself (`after = before`).
    """

    @T.prim_func
    def before(dlpack_handle: T.handle, axis: T.int64) -> T.int64:
        # NOTE(review): judging by the variable names, the integer field
        # indices passed to `T.tvm_struct_get` presumably select the ndim,
        # strides and shape fields -- confirm against the struct-field enum.
        ndim: T.int32 = T.tvm_struct_get(dlpack_handle, 0, 5, "int32")
        stride_ptr: T.handle("int64") = T.tvm_struct_get(dlpack_handle, 0, 4, "handle")
        if T.isnullptr(stride_ptr):
            # No stride pointer: multiply together the shape entries after `axis`.
            shape_ptr: T.handle("int64") = T.tvm_struct_get(dlpack_handle, 0, 3, "handle")
            shape = T.decl_buffer(ndim, "int64", data=shape_ptr)
            # Scalar (zero-dimensional) accumulator, declared without a `data` pointer.
            product = T.decl_buffer([], "int64")
            product[()] = 1
            for dim in range(axis + 1, ndim):
                product[()] = product[()] * shape[dim]
            return product[()]
        else:
            # Stride pointer present: read the requested entry directly.
            strides = T.decl_buffer(ndim, "int64", data=stride_ptr)
            stride: T.int64 = strides[axis]
            return stride

    # The pass is expected to be a no-op here.
    # NOTE(review): `_check` is defined elsewhere in this file; presumably it
    # applies the pass to `before` and compares the result with `after` -- confirm.
    after = before
    _check(before, after)


if __name__ == "__main__":
tvm.testing.main()