Labels: needs-triage, type: bug
Description
The DefaultCUDATensorCore schedule rule currently supports wmma only for the dtype combinations fp16·fp16→fp16 and s8·s8→s32: https://github.com/apache/tvm/blob/main/src/meta_schedule/schedule_rule/schedule_rule.cc#L122-L171
When I use MetaSchedule to tune a matmul with DefaultCUDATensorCore, the result is incorrect whenever a Tensor Core intrinsic is used. However, when I change the matmul's input/output dtypes to a combination that the wmma intrinsics in DefaultCUDATensorCore do not support (so the kernel falls back to CUDA cores), the result is correct.
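For reference, here is the dtype matrix exercised by the repro script below, as a minimal sketch (the dict name and structure are mine, for illustration; the annotations come from the comments in the script):

# Summary of the (in_dtype, out_dtype) cases in the repro below.
# "Tensor Core" means one of DefaultCUDATensorCore's wmma intrinsic
# groups applies per the schedule_rule.cc link above.
CASES = {
    ("float32", "float32"): ("CUDA core",   "correct"),
    ("float16", "float32"): ("CUDA core",   "correct"),
    ("float16", "float16"): ("Tensor Core", "incorrect"),  # fp16·fp16→fp16 wmma
    ("int8",    "int32"):   ("Tensor Core", "incorrect"),  # s8·s8→s32 wmma
    ("int8",    "float32"): ("CUDA core",   "correct"),
}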
Expected behavior
MetaSchedule produces correct results when Tensor Cores are used.
Actual behavior
The result is wrong whenever the tuned kernel uses Tensor Core (wmma) intrinsics.
Environment
Operating System: Ubuntu 20.04
TVM version: main branch
GPU: NVIDIA A100
Steps to reproduce
import tempfile

import numpy as np

import tvm
from tvm import meta_schedule as ms
from tvm.contrib import nvcc
from tvm.meta_schedule import tune_tir
from tvm.meta_schedule.testing import te_workload
from tvm.target import Target


def test_tune_tir_matmul_cuda_tensor_core(in_dtype, out_dtype, n, m, k):
    mod = tvm.te.create_prim_func(
        te_workload.matmul(n, m, k, in_dtype=in_dtype, out_dtype=out_dtype)
    )
    target = Target("nvidia/nvidia-a100")
    with tempfile.TemporaryDirectory() as work_dir:
        database = tune_tir(
            mod=mod,
            target=target,
            work_dir=work_dir,
            num_trials_per_iter=32,
            max_trials_global=32,
            strategy="replay-trace",
        )
        sch = ms.tir_integration.compile_tir(database, mod, target)
    if sch is None:
        raise RuntimeError("No valid schedule found!")
    ctx = tvm.cuda()
    if nvcc.have_tensorcore(ctx.compute_version):
        with tvm.transform.PassContext():
            func = tvm.build(sch.mod["main"], [], "cuda")
        # print(func.imported_modules[0].get_source())
        # print(sch.mod.script())
        print(sch.trace)
        a_np = np.random.uniform(-10, 10, size=(n, k)).astype(in_dtype)
        b_np = np.random.uniform(-10, 10, size=(k, m)).astype(in_dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=out_dtype), ctx)
        func(a, b, c)
        np.testing.assert_allclose(
            c.numpy(),
            np.matmul(a_np, b_np, dtype=out_dtype),
            rtol=1e-6,
            atol=1e-6,
        )
        print("passed!")


if __name__ == "__main__":
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float32", out_dtype="float32", n=128, m=128, k=128)  # CUDA core, correct
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", out_dtype="float32", n=128, m=128, k=128)  # CUDA core, correct
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", out_dtype="float16", n=128, m=128, k=128)  # Tensor Core, incorrect
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", out_dtype="int32", n=128, m=128, k=128)  # Tensor Core, incorrect
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", out_dtype="float32", n=128, m=128, k=128)  # CUDA core, correct
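One caveat when reading the float16/float16 failure: rtol=1e-6 is far tighter than fp16 precision allows, so a looser check helps separate ordinary rounding noise from a genuinely wrong kernel. A minimal sketch, with a helper name and tolerance values that are my assumptions rather than part of the original report:

import numpy as np

# Hypothetical helper (not in the original report): compare the device
# result against a float32 reference at fp16-appropriate tolerances.
def check_fp16_result(c, a_np, b_np):
    # Compute the reference in float32 so the reference itself does not
    # accumulate fp16 rounding error over the k=128 reduction.
    ref = np.matmul(a_np.astype("float32"), b_np.astype("float32"))
    # fp16 carries roughly 3 significant decimal digits, so the assumed
    # rtol=1e-2 below is a loose but fp16-plausible bound.
    np.testing.assert_allclose(c.numpy().astype("float32"), ref, rtol=1e-2, atol=1.0)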