
[Bug] MetaSchedule produces incorrect results when TensorCore is used #13204

@wllqwzx

Description

The DefaultCUDATensorCore schedule rule currently supports wmma only for the dtype combinations fp16 x fp16 -> fp16 and s8 x s8 -> s32: https://github.com/apache/tvm/blob/main/src/meta_schedule/schedule_rule/schedule_rule.cc#L122-L171

When I use MetaSchedule to tune a matmul with DefaultCUDATensorCore, the result is incorrect whenever TensorCore is actually used. However, when I change the matmul's input/output dtypes to a combination that the wmma intrinsics in DefaultCUDATensorCore do not support (so the kernel falls back to CUDA cores), the result is correct.
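
To tell at a glance whether a tuned schedule actually went down the TensorCore path, the following helper can be applied to the sch object from the script below. This is a minimal sketch, not part of the original report; it assumes the printed trace of a tensorized schedule mentions the wmma intrinsic names.

def uses_tensor_core(sch) -> bool:
    # Heuristic sketch: schedules produced by DefaultCUDATensorCore carry wmma
    # intrinsic names (e.g. wmma_sync / wmma_load / wmma_store) in their trace,
    # while plain CUDA-core schedules do not.
    return "wmma" in str(sch.trace)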

Expected behavior

MetaSchedule produces correct results when TensorCore is used.

Actual behavior

The result is incorrect for the dtype combinations that trigger TensorCore (see the reproduction below).

Environment

Operating System: Ubuntu 20.04
TVM version: main branch
GPU: NVIDIA A100

Steps to reproduce

import tempfile
import numpy as np
import tvm
from tvm.contrib import nvcc
from tvm import meta_schedule as ms
from tvm.meta_schedule import tune_tir
from tvm.target import Target
from tvm.meta_schedule.testing import te_workload

def test_tune_tir_matmul_cuda_tensor_core(in_dtype, out_dtype, n, m, k):
    mod = tvm.te.create_prim_func(
        te_workload.matmul(n, m, k, in_dtype=in_dtype, out_dtype=out_dtype)
    )
    target = Target("nvidia/nvidia-a100")

    with tempfile.TemporaryDirectory() as work_dir:
        # Tune with the default CUDA schedule rules (which include
        # DefaultCUDATensorCore) using the replay-trace search strategy.
        database = tune_tir(
            mod=mod,
            target=target,
            work_dir=work_dir,
            num_trials_per_iter=32,
            max_trials_global=32,
            strategy="replay-trace",
        )
        sch = ms.tir_integration.compile_tir(database, mod, target)
        if sch is None:
            raise RuntimeError("No valid schedule found!")
        ctx = tvm.cuda()
        if nvcc.have_tensorcore(ctx.compute_version):
            with tvm.transform.PassContext():
                func = tvm.build(sch.mod["main"], [], "cuda")
                # print(func.imported_modules[0].get_source())
                # print(sch.mod.script())
                print(sch.trace)
            a_np = np.random.uniform(-10, 10, size=(n, k)).astype(in_dtype)
            b_np = np.random.uniform(-10, 10, size=(k, m)).astype(in_dtype)
            a = tvm.nd.array(a_np, ctx)
            b = tvm.nd.array(b_np, ctx)
            c = tvm.nd.array(np.zeros((n, m), dtype=out_dtype), ctx)
            func(a, b, c)
            np.testing.assert_allclose(
                c.asnumpy(),
                np.matmul(a_np, b_np, dtype=out_dtype),
                rtol=1e-6,
                atol=1e-6,
            )
            print("passed!")

if __name__ == "__main__":
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float32", out_dtype="float32", n=128, m=128, k=128) # cuda core, correct
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", out_dtype="float32", n=128, m=128, k=128) # cuda core, correct
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", out_dtype="float16", n=128, m=128, k=128) # tensor core, incorrect
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", out_dtype="int32", n=128, m=128, k=128)  # tensor core, incorrect
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", out_dtype="float32", n=128, m=128, k=128)  # cuda core, correct

cc @vinx13 @junrushao @masahi @zxybazh
