Skip to content
7 changes: 7 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <mma.h>\n";
}

decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
decl_stream << "#else\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
decl_stream << "#endif\n";

decl_stream << "\n#ifdef _WIN32\n";
decl_stream << " using uint = unsigned int;\n";
decl_stream << " using uchar = unsigned char;\n";
Expand Down
46 changes: 40 additions & 6 deletions src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
: "l"((void *)({smem_addr}))
);
__asm__ __volatile__(
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
#if TVM_ENABLE_L2_PREFETCH
"cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
#endif
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
);
}
)";
Expand All @@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& global_elem_offset,
const std::string& bytes,
const std::string& predicate_value) {
CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" ||
bytes == "1")
<< "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
std::string predicated_asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
int src_bytes = {pred_guard} ? {bytes} : 0;
int pred_guard = (int){pred_guard};
__asm__ __volatile__(
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
"{ .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
#if TVM_ENABLE_L2_PREFETCH
" @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
#else
" @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
#endif
" @!p {store_shared};}"
:: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg}
);
}
)";
auto [store_shared, nopreg] = [](const std::string& bytes) {
if (bytes == "16")
return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}",
"\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)");
else if (bytes == "12")
return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)");
else if (bytes == "8")
return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)");
else if (bytes == "4")
return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)");
else if (bytes == "2")
return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)");
else if (bytes == "1")
return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)");
else
return std::make_tuple("", "");
}(bytes);

Replacer replacer;
replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
replacer.register_rule("{store_shared}", store_shared);
replacer.register_rule("{nopreg}", nopreg);
replacer.register_rule("{pred_guard}", predicate_value);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
Expand Down
31 changes: 30 additions & 1 deletion src/tir/transforms/inject_ptx_async_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,41 @@ class PTXAsyncCopyInjector : public StmtMutator {
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
} else {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();

auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by merging dynamic
// shared memory.
// A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)]
auto* add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes), predicate_value}));
}
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/tir/transforms/lower_async_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
: IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}

// TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend
Stmt VisitStmt_(const ForNode* loop) final {
// if for loop is not within async_commit_queue_scope
if (!async_queue_id_.has_value()) {
return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
}

// if for loop is not a memcpy of a contiguous region
// if the for loop is not a memcpy of a contiguous region, it may correspond to CUDA cp.async behavior
std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
mem_copy->source->region.size() != 1) {
LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
}

// now that we are about to perform the `copy` transform
Expand Down
18 changes: 0 additions & 18 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,23 +879,5 @@ def test_meta(hexagon_session):
)


def test_non_contiguous():
    """Test Non Contiguous memory lowering."""
    schedule = tvm.tir.Schedule(conv2d_async_non_contig)
    hexagon_target = tvm.target.hexagon("v68", link_params=True)
    expected_error = r"Unable to lower async dma due to non contiguous memory access"
    # Non contiguous memory access cannot be lowered to an async dma yet,
    # so the build is expected to raise the error matched below.
    pass_context = tvm.transform.PassContext(config={"tir.use_async_copy": 1})
    with pytest.raises(tvm.TVMError, match=expected_error), pass_context:
        tvm.build(
            schedule.mod["main"],
            target=tvm.target.Target(hexagon_target, host=hexagon_target),
        )


# Entry point: dispatch to TVM's pytest-based test runner when this file
# is executed directly (rather than collected by pytest).
if __name__ == "__main__":
    tvm.testing.main()
Loading