Skip to content

Commit e08df2a

Browse files
committed
tensorized C_warp init
1 parent ae06789 commit e08df2a

File tree

5 files changed

+73
-14
lines changed

5 files changed

+73
-14
lines changed

include/tvm/tir/builtin.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,7 @@ TVM_DLL const Op& ptx_mma_sp();
633633
TVM_DLL const Op& ptx_ldmatrix();
634634

635635
TVM_DLL const Op& mma_store();
636+
TVM_DLL const Op& mma_fill();
636637

637638
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
638639
/*!

src/target/source/codegen_cuda.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -822,16 +822,24 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
822822
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
823823
smem_ptr, smem_elem_offset);
824824
} else if (op->op.same_as(builtin::mma_store())) {
825-
std::string dst = this->PrintExpr(op->args[1]);
826-
std::string src = this->PrintExpr(op->args[2]);
827-
std::string src_offset = this->PrintExpr(op->args[3]);
828-
std::string stride = this->PrintExpr(op->args[4]);
825+
std::string dst = this->PrintExpr(op->args[2]);
826+
std::string src = this->PrintExpr(op->args[3]);
827+
std::string src_offset = this->PrintExpr(op->args[4]);
828+
std::string stride = this->PrintExpr(op->args[5]);
829829

830830
os << "for (int i = 0; i < 4; ++i) {\n";
831831
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
832832
<< " + (threadIdx.x % 4) * 2 + i % 2]"
833833
<< " = " << src << "[" << src_offset << " + i];\n";
834834
os << "}\n";
835+
} else if (op->op.same_as(builtin::mma_fill())) {
836+
std::string num_elem = this->PrintExpr(op->args[0]);
837+
std::string dst = this->PrintExpr(op->args[1]);
838+
std::string dst_offset = this->PrintExpr(op->args[2]);
839+
840+
os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
841+
os << dst << "[" << dst_offset << " + i] = 0.0;\n";
842+
os << "}\n";
835843
} else {
836844
CodeGenC::VisitExpr_(op, os);
837845
}

src/tir/op/builtin.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
250250
TIR_DEFINE_BUILTIN_FUNC(mma_store)
251251
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
252252

253+
TIR_DEFINE_BUILTIN_FUNC(mma_fill)
254+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
255+
253256
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
254257
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
255258

src/tir/transforms/lower_warp_memory.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,13 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
117117
if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) {
118118
int num_matrix = op->args[1].as<IntImmNode>()->value;
119119
warp_coeff_ = num_matrix * 2;
120+
} else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
121+
LOG(INFO) << op->args[0];
122+
auto* ptr = op->args[0].as<IntImmNode>();
123+
CHECK(ptr);
124+
warp_coeff_ = ptr->value;
120125
}
126+
121127
StmtExprVisitor::VisitExpr_(op);
122128
}
123129

@@ -284,9 +290,20 @@ class WarpAccessRewriter : protected StmtExprMutator {
284290
if (op->op.same_as(builtin::mma_store())) {
285291
Array<PrimExpr> new_args = op->args;
286292
PrimExpr local_offset, group;
287-
if (op->args[2].get() == buffer_) {
288-
std::tie(local_offset, group) = SplitIndexByGroup(op->args[3]);
289-
new_args.Set(3, local_offset);
293+
if (op->args[3].get() == buffer_) {
294+
std::tie(local_offset, group) = SplitIndexByGroup(op->args[4]);
295+
new_args.Set(4, local_offset);
296+
return Call(op->dtype, op->op, new_args);
297+
}
298+
return GetRef<PrimExpr>(op);
299+
}
300+
301+
if (op->op.same_as(builtin::mma_fill())) {
302+
Array<PrimExpr> new_args = op->args;
303+
PrimExpr local_offset, group;
304+
if (op->args[1].get() == buffer_) {
305+
std::tie(local_offset, group) = SplitIndexByGroup(op->args[2]);
306+
new_args.Set(2, local_offset);
290307
return Call(op->dtype, op->op, new_args);
291308
}
292309
return GetRef<PrimExpr>(op);

tests/python/unittest/test_mma_16x8x8_4k_tune.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,43 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
196196
tx = T.env_thread("threadIdx.x")
197197
T.launch_thread(tx, 32)
198198

199-
T.evaluate(T.mma_store("m16n8", C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"))
199+
T.evaluate(T.mma_store(16, 8, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"))
200+
201+
202+
@T.prim_func
203+
def mma_fill_desc(a: T.handle) -> None:
204+
C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp")
205+
206+
with T.block("root"):
207+
T.reads()
208+
T.writes(C_warp[0:32, 0:4])
209+
for i0, i1 in T.grid(32, 4):
210+
with T.block("C_warp"):
211+
i_init = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
212+
j_init = T.axis.spatial(8, (i0 % 4) * 2 + i1 % 2)
213+
T.reads()
214+
T.writes(C_warp[i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2])
215+
C_warp[i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2] = T.float32(0)
216+
217+
218+
@T.prim_func
219+
def mma_fill_impl(a: T.handle) -> None:
220+
C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp", offset_factor=1)
221+
222+
with T.block("root"):
223+
T.reads()
224+
T.writes(C_warp[0:32, 0:4])
225+
tx = T.env_thread("threadIdx.x")
226+
T.launch_thread(tx, 32)
227+
228+
T.evaluate(T.mma_fill(4, C_warp.data, C_warp.elem_offset, dtype="float32"))
200229

201230

202231
tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
203232
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
204233
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
205234
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
235+
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
206236

207237
N = 4096
208238
M = 4096
@@ -381,7 +411,8 @@ def lambda_b(i, j):
381411
sch.reorder(f_1, f_2, f_0, f_3)
382412
fused_1 = sch.fuse(f_1, f_2)
383413
fused_2 = sch.fuse(f_0, f_3)
384-
sch.bind(fused_1, "threadIdx.x")
414+
# sch.bind(fused_1, "threadIdx.x")
415+
sch.tensorize(fused_1, "mma_fill")
385416

386417
warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
387418
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
@@ -394,7 +425,6 @@ def lambda_b(i, j):
394425
# return
395426

396427
sch.tensorize(fused_1, "mma_store")
397-
# sch.bind(fused_1, "threadIdx.x")
398428

399429

400430
ir_module = tvm.IRModule({"main": workload})
@@ -440,7 +470,7 @@ def lambda_b(i, j):
440470
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
441471
print("ok")
442472

443-
# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
444-
# gflops = (N * M * K) * 2 / 1e9
445-
# time_ms = evaluator(a, b, c).mean * 1e3
446-
# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
473+
evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
474+
gflops = (N * M * K) * 2 / 1e9
475+
time_ms = evaluator(a, b, c).mean * 1e3
476+
print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))

0 commit comments

Comments
 (0)