
Conversation

@LeiWang1999 (Contributor) commented Apr 29, 2024

Major changes of this pull request:

  • Change the fp8-related tests' requires_cuda_compute_version requirement from 9 to 8.9, since the sm_89 (Ada) architecture also supports fp8 tensor cores and is the platform I tested on (see the sketch after this list).
  • Improve fp8 vector load/store capabilities; previously, TVM only supported float8x4/x2/x1 loads, but this PR adds support for float8x8 and float8x16 loads.
  • Refactor the interface of the get_mma_intrin_group and get_mma_intrin functions, since the prior implementation assumed that input A and input B share the same datatype, whereas fp8 tensor cores can process combinations such as e5m2×e5m2, e5m2×e4m3, e4m3×e4m3, or e4m3×e5m2. Note: this change may affect code in MLC that uses get_mma_intrin_group.
  • Implement support for fp8 mma code generation and the associated tests.
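
A minimal sketch of the test gating mentioned in the first bullet, assuming the stock tvm.testing decorator (the test name here is only illustrative):

import tvm.testing

@tvm.testing.requires_cuda_compute_version(8, 9)  # sm_89 (Ada) and newer
def test_fp8_e4m3_mma():
    ...  # fp8 tensor-core test body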

You can check correctness with the following script:

import tvm
from tvm import te
import numpy as np
import tvm.testing
from tvm.script import tir as T
import os
from tvm.tir.tensor_intrin.cuda import (
    get_mma_intrin_group,
    shared_16x16_to_ldmatrix_32x8_layout,
    shared_32x16_to_ldmatrix_32x16_layout,
    shared_16x32_to_ldmatrix_32x16_layout,
)

M = 1024
N = 1024
K = 1024

BM = 64
BN = 64
BK = 64
warp_size = 32
block_row_warps = 2
block_col_warps = 4

indtype = "e4m3_float8"
out_dtype = "float32"
# indtype = "int8"
# out_dtype = "int32"
intrin_group = get_mma_intrin_group(
    "shared",
    "global",
    in_dtype=indtype,
    out_dtype=out_dtype,
    trans_a=False,
    trans_b=True,
    not_use_mma_store_intrinic=False,
)
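# The returned group is a dict mapping intrinsic roles to registered tensor
# intrinsics; the keys used below are "init", "load_a", "load_b", "compute" and "store".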

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype=indtype)
        B = T.match_buffer(b, [N, K], dtype=indtype)
        C = T.match_buffer(c, [M, N], dtype=out_dtype)

        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)  # init must match out_dtype; use T.int32(0) for the int8/int32 path
                C[vi, vj] = C[vi, vj] + \
                    A[vi, vk].astype(out_dtype) * B[vj, vk].astype(out_dtype)


ir_module = MyModule
print(ir_module)
sch = tvm.tir.Schedule(ir_module, debug_mask="all")

block_b = sch.get_block("B")

(i, j, k) = sch.get_loops(block_b)
by, i = sch.split(i, factors=[None, BM])
bx, j = sch.split(j, factors=[None, BN])
bk, k = sch.split(k, factors=[None, BK])

sch.reorder(by, bx, bk, i, j, k)

sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")


block_b_tz, block_b_inner_i = sch.split(
    i, factors=[block_row_warps, None])

block_b_ty, block_b_inner_j = sch.split(
    j, factors=[block_col_warps, None])

sch.reorder(block_b_tz, block_b_ty, bk, block_b_inner_i, block_b_inner_j, k)

sch.bind(block_b_tz, "threadIdx.z")
sch.bind(block_b_ty, "threadIdx.y")

# schedule the shared memory

def fetch_to_shared(block, idx):
    block_read = sch.cache_read(block, idx, "shared")
    sch.compute_at(block_read, bk)
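    # Vectorize by 16 fp8 elements per access; this exercises the float8x16
    # vector load/store support added in this PR.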
    vector_size = 16
    fused = sch.fuse(*sch.get_loops(block_read)[-2:])
    _, f_1, f_2, f_3 = sch.split(
        fused, factors=[None, block_col_warps, warp_size, vector_size])
    sch.bind(f_2, "threadIdx.x")
    sch.bind(f_1, "threadIdx.y")
    sch.vectorize(f_3)
    offset = 0
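    # Pad the shared buffer so its second-to-last axis stride satisfies
    # stride % 32 == offset; offset is 0 here, but a non-zero offset is the usual
    # knob for avoiding shared-memory bank conflicts.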
    sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)

# schedule A
fetch_to_shared(block_b, 0)
# schedule B
fetch_to_shared(block_b, 1)


# blockize for mma tensorize

mma_m = 16
mma_n = 16
mma_k = 32
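# 16x16x32 is the tile shape handled by the fp8/int8 mma intrinsics tensorized
# below (the underlying PTX mma for e4m3/e5m2 uses m16n8k32 fragments).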

block_b_inner_i, block_b_inner_i_tc = sch.split(
    block_b_inner_i, factors=[None, mma_m])
block_b_inner_j, block_b_inner_j_tc = sch.split(
    block_b_inner_j, factors=[None, mma_n])
k, k_tc = sch.split(k, factors=[None, mma_k])

sch.reorder(block_b_inner_i, block_b_inner_j,
            k, block_b_inner_i_tc, block_b_inner_j_tc, k_tc)

A_warp = sch.cache_read(block_b, 0, "warp")
B_warp = sch.cache_read(block_b, 1, "warp")
sch.compute_at(A_warp, k)
sch.compute_at(B_warp, k)
C_warp = sch.cache_write(block_b, 0, "warp")
sch.reverse_compute_at(C_warp, block_b_ty)

ii, jj = sch.get_loops(C_warp)[-2:]
io, ii = sch.split(ii, factors=[None, mma_m])
jo, ji = sch.split(jj, factors=[None, mma_n])
sch.reorder(io, jo, ii, ji)


def tile_wmma_fragment(block_read, height, width):
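    # height and width are unused here: the last two loops of the warp-cache block
    # already span a single mma fragment, so the outer loop is returned for
    # tensorization below.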
    i, j = sch.get_loops(block_read)[-2:]
    return i

loop_a = tile_wmma_fragment(A_warp, mma_m, mma_k)

loop_b = tile_wmma_fragment(B_warp, mma_n, mma_k)

block_init_c = sch.decompose_reduction(
    block_b, bk)
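# decompose_reduction splits the zero-initialization of C into its own block so
# that it can be tensorized with the "init" intrinsic below.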

def index_map_A(i, j):
    return (
        i // 16,
        j // 32,
        *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
    )

def index_map_B(i, j):
    return (
        i // 32,
        j // 16,
        *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16),
    )

def index_map_C(i, j):
    return (
        i // 16,
        j // 16,
        *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
    )


sch.transform_layout(A_warp, ("write", 0), index_map_A)
sch.transform_layout(B_warp, ("write", 0), index_map_A)
sch.transform_layout(C_warp, ("read", 0), index_map_C)
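# With trans_b=True, B is stored as (N, K), so its 16x32 warp tile uses the same
# ldmatrix index map as A; index_map_B (32x16) would only apply to a
# non-transposed B and is left unused here.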


sch.tensorize(loop_a, intrin_group["load_a"])
sch.tensorize(loop_b, intrin_group["load_b"])

sch.tensorize(block_b_inner_i_tc, intrin_group["compute"])

sch.tensorize(sch.get_loops(block_init_c)[-2], intrin_group["init"])
sch.tensorize(sch.get_loops(C_warp)[-2], intrin_group["store"])


ctx = tvm.cuda(0)
cuda_mod = tvm.build(sch.mod, target="cuda")
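# To inspect the generated CUDA source (including the fp8 mma calls), uncomment:
# print(cuda_mod.imported_modules[0].get_source())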

def map_numpy_type(intype):
    # The 'float8_e4m3fn' / 'float8_e5m2' numpy dtypes are provided by the
    # ml_dtypes package; plain numpy has no native fp8 types.
    typemap = {
        'e4m3_float8': 'float8_e4m3fn',
        'e5m2_float8': 'float8_e5m2',
    }
    if intype in typemap:
        return typemap[intype]
    else:
        return intype

numpytype_a = map_numpy_type(indtype)
numpytype_b = map_numpy_type(indtype)
numpytype_c = map_numpy_type(out_dtype)
a = np.random.uniform(low=-5, high=5, size=(M*K)).reshape((M, K)).astype(numpytype_a)
b = np.random.uniform(low=-5, high=5, size=(N*K)).reshape((N, K)).astype(numpytype_b)
# Compute the reference in float32 from the fp8-quantized inputs.
out = np.matmul(a.astype("float32"), b.T.astype("float32"))

print("numpy_simulated:", out)

cuda_a = tvm.nd.array(a, ctx)
cuda_b = tvm.nd.array(b, ctx)
cuda_c = tvm.nd.array(np.zeros((M, N)).astype(numpytype_c), ctx)
cuda_mod(cuda_a, cuda_b, cuda_c)

print("codegen:", cuda_c)
num_flops = 2 * M * K * N
num_runs = 1
timer_cuda_mod = cuda_mod.time_evaluator(
    cuda_mod.entry_name, ctx, number=num_runs)

t = timer_cuda_mod(cuda_a, cuda_b, cuda_c).mean
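# t is in seconds, so num_flops / (t * 1e3) / 1e6 equals num_flops / t / 1e9, i.e. GFLOPS.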

GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." %
      (num_runs, t * 1e3, GFLOPS))

expected output:

numpy_simulated: [[-410.33817   -30.429443 -470.51312  ...   64.58632  -381.49658
    14.920105]
 [  56.357788  744.9746    -29.630783 ...  -44.779022  298.5943
   -24.109558]
 [  77.765305 -426.8894    286.35736  ...   10.655792 -129.63507
   232.30026 ]
 ...
 [  39.094635  -47.508118 -225.59912  ...  775.10614  -109.92264
   268.50952 ]
 [-813.8422    111.21069  -316.5697   ...  455.90875   -37.09839
   478.28406 ]
 [ 122.78345   148.104     340.1291   ... -304.5721   -115.578735
  -639.9563  ]]
codegen: [[-410.28125    -30.441406  -470.09375   ...   64.66406   -381.5
    14.8203125]
 [  56.367188   744.8125     -29.597656  ...  -44.695312   298.625
   -24.148438 ]
 [  77.65625   -426.71875    286.3125    ...   10.746094  -129.6875
   232.34375  ]
 ...
 [  39.191406   -47.539062  -225.57812   ...  774.9375    -109.875
   268.46875  ]
 [-813.625      111.109375  -316.46875   ...  455.96875    -37.08203
   478.0625   ]
 [ 122.75       148.10938    339.84375   ... -304.5       -115.546875
  -639.8125   ]]

Please CC @yzh119

@Hzfengsy (Member) commented:

cc @vinx13

@LeiWang1999 (Contributor, Author) commented:

I just removed the modification to the signature of get_mma_intrin_group, so this PR no longer affects code that uses this function. I believe this PR is now ready for review. @Hzfengsy

@spectrometerHBH merged commit 32929c6 into apache:main on Feb 8, 2025