Skip to content

Commit c2e314c

Browse files
committed
tuned int8 4k, 91 TOPS
1 parent 94d9d96 commit c2e314c

File tree

3 files changed

4 additions & 7 deletions

3 files changed

4 additions & 7 deletions

src/target/source/codegen_cuda.cc

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -821,8 +821,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
821821

822822
if (trans && op->dtype.bits() == 8) {
823823
std::string smem_stride = this->PrintExpr(op->args[6]);
824-
LOG(INFO) << op->dtype;
825-
CHECK(num == 4);
824+
ICHECK(num == 4);
826825
os << "for (int i = 0; i < 16; ++i) {\n";
827826
os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
828827
<< "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +

src/tir/transforms/lower_warp_memory.cc

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -492,13 +492,11 @@ namespace transform {
492492
Pass LowerWarpMemory() {
493493
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
494494
auto* n = f.CopyOnWrite();
495-
// LOG(INFO) << f;
496495
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
497496
int warp_size = 32;
498497
WarpMemoryRewriter warp_memory_rewriter(warp_size);
499498
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
500499
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
501-
LOG(INFO) << f;
502500
return f;
503501
};
504502
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});

tests/python/unittest/test_mma_16x8x32_4k_tune.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -323,9 +323,9 @@ def schedule(sch: tir.Schedule):
323323
k_factors = sch.sample_perfect_tile(k, n=3)
324324
num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
325325
else:
326-
i_factors = [4, 8, 2, 4, 1]
327-
j_factors = [1, 64, 2, 1, 2]
328-
k_factors = [64, 2, 1]
326+
i_factors = [1, 32, 1, 4, 2]
327+
j_factors = [8, 4, 4, 2, 1]
328+
k_factors = [32, 2, 2]
329329

330330
num_ty = i_factors[2] * j_factors[2]
331331

0 commit comments

Comments
 (0)