4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
@@ -614,9 +614,9 @@ TVM_DLL Pass InstallDebugSpans();
TVM_DLL Pass UnifyThreadBinding();

/*!
* A pass to merge multiple TIR-level dynamic shared memory allocations into one
* A pass to merge multiple TIR-level shared memory allocations into one
*/
TVM_DLL Pass MergeDynamicSharedMemoryAllocations();
TVM_DLL Pass MergeSharedMemoryAllocations();

/*!
* \brief This pass is post-scheduling pass to convert all
6 changes: 3 additions & 3 deletions python/tvm/tir/transform/transform.py
@@ -1000,16 +1000,16 @@ def UnifyThreadBinding():
return _ffi_api.UnifyThreadBinding() # type: ignore


def MergeDynamicSharedMemoryAllocations():
"""This pass merges multiple TIR-level dynamic shared memory allocations
def MergeSharedMemoryAllocations():
"""This pass merges multiple TIR-level shared memory allocations
into one allocation.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore
return _ffi_api.MergeSharedMemoryAllocations() # type: ignore
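As a usage sketch (not part of this diff, and assuming a TVM build that already contains the rename), the binding is applied like any other TIR pass, with static merging opted into via the PassContext option registered in driver_api.cc below:

```python
import tvm
from tvm import tir

def merge_shared(mod: tvm.IRModule) -> tvm.IRModule:
    """Sketch: `mod` is assumed to be flat TIR, i.e. after StorageFlatten/FlattenBuffer."""
    # "tir.merge_static_smem" defaults to False; set it to also merge static allocations.
    with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
        return tir.transform.MergeSharedMemoryAllocations()(mod)
```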


def ConvertForLoopsToSerial():
3 changes: 2 additions & 1 deletion src/driver/driver_api.cc
@@ -52,6 +52,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
@@ -584,7 +585,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
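The new option rides the same PassContext mechanism as `tir.use_async_copy` above. A minimal sketch of flipping it during a CUDA build (a toy schedule with no shared memory, shown only for where the flag goes; requires a CUDA toolchain):

```python
import tvm
from tvm import te

n = 1024
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))

# Opt in to merging static shared memory allocations during lowering.
with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
    mod = tvm.build(s, [A, B], target="cuda")
```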
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/verify_gpu_code.cc
@@ -179,7 +179,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectVirtualThread());
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
pass_list.push_back(tir::transform::LowerIntrin());
// Convert Function to IRModule
transform::PassContext pass_ctx = transform::PassContext::Current();
src/tir/transforms/merge_dynamic_shared_memory_allocations.cc → src/tir/transforms/merge_shared_memory_allocations.cc
@@ -18,9 +18,10 @@
*/

/*!
* \file merge_dynamic_shared_memory_allocations.cc
* \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation.
* This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.
* \file merge_shared_memory_allocations.cc
* \brief Each GPU kernel is allowed at most one dynamic shared memory allocation;
* static allocations may also be merged so they share a single buffer.
* This pass merges multiple TIR-level dynamic or static shared memory allocations into one
* allocation.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
@@ -45,6 +46,11 @@ bool IsDynamicSharedMemory(Var buffer_var) {
return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
}

bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == "";
}
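The dynamic/static split keys entirely off the pointer storage scope: rank `shared` with tag `".dyn"` is dynamic, rank `shared` with an empty tag is static. The same classification on raw scope strings, as a self-contained sketch:

```python
def classify_shared(scope: str) -> str:
    # "shared.dyn" -> rank "shared", tag "dyn"; plain "shared" has an empty tag.
    rank, _, tag = scope.partition(".")
    if rank != "shared":
        return "not shared"
    return "dynamic" if tag == "dyn" else "static"

assert classify_shared("shared.dyn") == "dynamic"
assert classify_shared("shared") == "static"
assert classify_shared("global") == "not shared"
```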

/*!
* \brief collect the mapping from the buffer var to its allocate
*/
@@ -53,11 +59,15 @@ class AllocateCollector : public StmtExprVisitor {
void VisitStmt_(const AllocateNode* op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_[op->buffer_var.get()] = op;
} else if (IsStaticSharedMemory(op->buffer_var)) {
static_shmem_allocs_[op->buffer_var.get()] = op;
}
StmtExprVisitor::VisitStmt_(op);
}
// The mapping from the original buffer var to its allocate
// The mapping from each dynamic shared memory buffer var to its allocate
std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
// The mapping from each static shared memory buffer var to its allocate
std::unordered_map<const VarNode*, const AllocateNode*> static_shmem_allocs_;
};
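In Python terms the collector just buckets every allocation by that classification; roughly (a sketch over (var, scope) pairs, reusing `classify_shared` from the sketch above, not the real visitor API):

```python
def collect_shmem_allocs(allocs):
    """allocs: iterable of (buffer_var, scope_string) pairs from a TIR walk."""
    dyn_shmem, static_shmem = {}, {}
    for var, scope in allocs:
        kind = classify_shared(scope)
        if kind == "dynamic":
            dyn_shmem[var] = scope
        elif kind == "static":
            static_shmem[var] = scope
    return dyn_shmem, static_shmem
```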

// Find a linear pattern of storage access
@@ -73,8 +83,9 @@ class AllocateCollector : public StmtExprVisitor {
// The storage needs to be kept alive between the Allocate and its last access.
// The free point is only inserted at the same scope as the Allocate.
//
class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
public:
explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true) : is_dynamic_(is_dynamic) {}
/*! \brief record the touch list of statement. */
struct StmtEntry {
// The statement
@@ -112,7 +123,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
}
}
@@ -143,7 +154,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
}
}
@@ -164,7 +175,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
}
}
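Because the finder records only buffers of the kind currently being merged, liveness reduces to first and last touch positions in the linear statement sequence. A compact sketch of that computation (abstract input, hypothetical names):

```python
def touch_liveness(linear_seq):
    """linear_seq: list of sets of buffer names touched by each statement, in order.
    Returns (gen, kill): first and last touch index for every buffer."""
    gen, kill = {}, {}
    for i, touched in enumerate(linear_seq):
        for buf in touched:
            gen.setdefault(buf, i)  # first touch: the merged slot must exist here
            kill[buf] = i           # last touch so far: after this the slot can be reused
    return gen, kill

gen, kill = touch_liveness([{"a"}, {"a", "b"}, {"b"}])
assert gen == {"a": 0, "b": 1} and kill == {"a": 1, "b": 2}
```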
@@ -217,6 +228,12 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

private:
// Returns whether the variable's shared memory kind (dynamic or static) matches the one being analyzed.
bool IsAppropriateSharedMemory(const Var& var) {
return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var);
}
// Whether the finder targets dynamic shared memory.
bool is_dynamic_{true};
// Whether already in thread env.
bool in_thread_env_{false};
// The scope stack.
@@ -226,18 +243,23 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
/*!
* \brief Merge buffers whose live ranges do not intersect, and rewrite the body accordingly
*/
class DynamicSharedMemoryRewriter : public StmtExprMutator {
class SharedMemoryRewriter : public StmtExprMutator {
public:
explicit DynamicSharedMemoryRewriter(
const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
: dyn_shmem_allocs_{dyn_shmem_allocs} {}
explicit SharedMemoryRewriter(
const std::unordered_map<const VarNode*, const AllocateNode*>& shmem_allocs,
bool is_dynamic = true)
: is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs} {
if (!is_dynamic) {
merged_buf_var_ = Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), "shared"));
}
}

/*!
* \brief Plan memory reuse for all the buffers allocated in the statement
* \param stmt the statement
*/
void PlanReuse(const Stmt& stmt) {
DynSharedMemLinearAccessPatternFinder finder;
void PlanReuse(const Stmt& stmt, bool is_dynamic = true) {
SharedMemLinearAccessPatternFinder finder(is_dynamic);
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_);
@@ -263,7 +285,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
for (const StorageEntry* e : all_entry) {
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
for (const VarNode* buffer : e->allocs[i]) {
const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
const AllocateNode* alloc = shmem_allocs_[buffer];
align[i] = std::max(align[i], alloc->dtype.bytes());
}
}
@@ -274,7 +296,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
PrimExpr inner_offset = 0;
for (const VarNode* buffer : e->allocs[i]) {
const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
const AllocateNode* alloc = shmem_allocs_[buffer];
buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
inner_offset += alloc->extents[0] * alloc->dtype.bytes();
inner_offset += indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]);
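The loop above packs each buffer at the next boundary of its slot's widest element type. The same arithmetic in plain Python (a sketch with concrete integer extents, where the pass itself works on `PrimExpr`s):

```python
def merged_offsets(allocs, align):
    """allocs: list of (extent_in_elements, dtype_bytes); align: slot alignment in bytes.
    Returns (per-buffer byte offsets, total merged size in bytes)."""
    offsets, offset = [], 0
    for extent, nbytes in allocs:
        offsets.append(offset)
        offset += extent * nbytes
        offset += (align - offset % align) % align  # same padding as the indexmod above
    return offsets, offset

# float32[100] followed by int8[13], 4-byte aligned:
assert merged_offsets([(100, 4), (13, 1)], 4) == ([0, 400], 416)
```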
@@ -293,7 +315,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
}

Stmt VisitStmt_(const AllocateNode* op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
if (IsAppropriateSharedMemory(op->buffer_var)) {
return StmtExprMutator::VisitStmt(op->body);
}
return StmtExprMutator::VisitStmt_(op);
@@ -319,9 +341,9 @@

template <typename Node>
Node VisitBufferAccess(Node node) {
if (IsDynamicSharedMemory(node->buffer->data)) {
if (IsAppropriateSharedMemory(node->buffer->data)) {
ICHECK_EQ(node->indices.size(), 1)
<< "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
<< "MergeSharedMemoryAllocations expects flat memory buffers, "
<< "and is to be run after "
<< "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
Array<PrimExpr> indices = {node->indices[0] +
@@ -342,10 +364,10 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
return it->second;
}

if (IsDynamicSharedMemory(buffer->data)) {
if (IsAppropriateSharedMemory(buffer->data)) {
ICHECK_EQ(buffer->shape.size(), 1)
<< "Buffer " << buffer << " has shape " << buffer->shape << ". "
<< "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
<< "MergeSharedMemoryAllocations expects flat memory buffers, "
<< "and is to be run after "
<< "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
auto writer = buffer.CopyOnWrite();
@@ -361,7 +383,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
Var buffer = Downcast<Var>(op->args[1]);
if (!IsDynamicSharedMemory(buffer)) {
if (!IsAppropriateSharedMemory(buffer)) {
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
@@ -381,7 +403,12 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
return indexdiv(it->second, dtype.bytes());
}

using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry;
// Returns whether the variable's shared memory kind (dynamic or static) matches the one being rewritten.
bool IsAppropriateSharedMemory(const Var& var) {
return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var);
}

using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
struct StorageEntry {
// The constant size of the buffer in bits, only used if it is constant
uint64_t const_nbits{0};
@@ -458,8 +485,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
// In both cases, we need to handle the gen event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
for (const VarNode* var : it->second.gen) {
ICHECK(dyn_shmem_allocs_.count(var));
const AllocateNode* alloc = dyn_shmem_allocs_[var];
ICHECK(shmem_allocs_.count(var));
const AllocateNode* alloc = shmem_allocs_[var];
StorageEntry* dst_entry = FindAlloc(alloc);
alloc_map_[var] = dst_entry;
}
@@ -578,10 +605,12 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
sym_free_list_.push_back(e);
}
}
// Whether the rewriter targets dynamic shared memory.
bool is_dynamic_{true};
// The var for the merged buffer
Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
// The mapping from the original buffer var to its allocate
std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
std::unordered_map<const VarNode*, const AllocateNode*> shmem_allocs_;
// The size of the merged buffer
PrimExpr merged_alloc_size_{0};
// The mapping from the original buffer var to its offset in the merged buffer
@@ -602,30 +631,36 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
support::Arena arena_;
};

Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) {
AllocateCollector collector;
collector(stmt);
if (collector.dyn_shmem_allocs_.size() > 1) {
DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
rewriter.PlanReuse(stmt);
return rewriter(std::move(stmt));
stmt = rewriter(std::move(stmt));
}
if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false);
rewriter.PlanReuse(stmt, false);
stmt = rewriter(std::move(stmt));
}
return stmt;
}

namespace transform {

Pass MergeDynamicSharedMemoryAllocations() {
Pass MergeSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
auto* n = f.CopyOnWrite();
n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", {});
}

TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
.set_body_typed(MergeDynamicSharedMemoryAllocations);
TVM_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations")
.set_body_typed(MergeSharedMemoryAllocations);

} // namespace transform
} // namespace tir
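Since the registered global is renamed too, downstream code that resolved the old name must be updated. A quick sanity check against a build containing this change:

```python
import tvm

# Only the renamed global should resolve after this PR.
assert tvm.get_global_func("tir.transform.MergeSharedMemoryAllocations") is not None
assert tvm.get_global_func("tir.transform.MergeDynamicSharedMemoryAllocations",
                           allow_missing=True) is None
```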
19 changes: 11 additions & 8 deletions src/tir/transforms/storage_rewrite.cc
@@ -380,13 +380,13 @@ class StoragePlanRewriter : public StmtExprMutator {
using StmtEntry = LinearAccessPatternFinder::StmtEntry;
using AllocEntry = LinearAccessPatternFinder::AllocEntry;

Stmt Rewrite(Stmt stmt, bool detect_inplace) {
Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
detect_inplace_ = detect_inplace;
// plan the rewrite
LinearAccessPatternFinder finder;
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
all_buffers_accessed_ = finder.all_buffers_accessed_;
this->PrepareNewAlloc();
// start rewrite
@@ -816,7 +816,8 @@ class StoragePlanRewriter : public StmtExprMutator {

// Memory plan algorithm
void PlanMemory(const std::vector<StmtEntry>& seq,
const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) {
const std::unordered_map<const VarNode*, AllocEntry>& alloc_info,
bool enable_reuse = true) {
std::unordered_set<const VarNode*> inplace_flag;

for (size_t i = 0; i < seq.size(); ++i) {
@@ -863,8 +864,8 @@ class StoragePlanRewriter : public StmtExprMutator {
}
}
if (dst_entry == nullptr) {
dst_entry =
FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
entry.num_physical_dimensions, enable_reuse);
}
dst_entry->allocs.emplace_back(alloc);
alloc_map_[var] = dst_entry;
@@ -917,7 +918,8 @@ class StoragePlanRewriter : public StmtExprMutator {
}

StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
const StorageScope& scope, size_t num_physical_dimensions) {
const StorageScope& scope, size_t num_physical_dimensions,
bool enable_reuse = true) {
ICHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
@@ -940,7 +942,7 @@ class StoragePlanRewriter : public StmtExprMutator {
(scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
(is_known_size && const_nbits <= 32));

if (is_small_array || !is_flat_memory_space) {
if (!enable_reuse || is_small_array || !is_flat_memory_space) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}

@@ -1702,8 +1704,9 @@ namespace transform {

Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
auto* n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, !merge_static_smem);
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
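Net effect of the `enable_reuse` plumbing: with `tir.merge_static_smem` set, `StorageRewrite` stops planning in-place reuse for these allocations and defers the layout to `MergeSharedMemoryAllocations`. One hedged way to observe it (requires a CUDA-enabled TVM build; the exact generated source varies by target and schedule):

```python
import tvm

def cuda_source(sched, args, merge_static: bool) -> str:
    """Sketch: build the same schedule twice and diff the generated CUDA."""
    with tvm.transform.PassContext(config={"tir.merge_static_smem": merge_static}):
        mod = tvm.build(sched, args, target="cuda")
    return mod.imported_modules[0].get_source()

# With merge_static=True one would expect a single merged declaration, e.g.
# `__shared__ uchar buf_shmem[...]` (the var introduced in this PR), in place of
# several separate static __shared__ arrays. That is an expectation, not a
# guarantee, since codegen details depend on target and schedule.
```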