Skip to content

Commit 88b763e

Browse files
committed
Refactored existing test using the VNNI intrinsic
1 parent 711a007 commit 88b763e

File tree

2 files changed

+10
-61
lines changed

2 files changed

+10
-61
lines changed

python/tvm/tir/tensor_intrin/vnni.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from tvm.script import tir as T
1919

2020

21+
# Tensorized intrinsic description and VNNI-specific implementation.
22+
# Equivalent to the ones in topi/x86/tensor_intrin.py
23+
24+
2125
@T.prim_func
2226
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
2327
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
@@ -52,9 +56,7 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
5256
B_i8x64 = B.vload([0, 0], dtype="int8x64")
5357
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
5458

55-
C[
56-
T.ramp(T.int32(0), 1, 16)
57-
] += T.call_llvm_pure_intrin( # Note: this is an update +=
59+
C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
5860
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
5961
T.uint32(0),
6062
T.int32x16(0),
@@ -64,6 +66,6 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
6466
)
6567

6668

67-
TensorIntrin.register(
68-
"dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin
69-
)
69+
INTRIN_NAME = "dot_16x1x16_uint8_int8_int32_cascadelake"
70+
71+
TensorIntrin.register(INTRIN_NAME, dot_product_desc, dot_product_intrin)

tests/python/unittest/test_meta_schedule_tune_relay.py

Lines changed: 2 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
from tvm.target.target import Target
4343
from tvm.tir.schedule import BlockRV, Schedule
4444
from tvm.tir.schedule.trace import Trace
45+
from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN
46+
4547

4648
logging.basicConfig()
4749
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
@@ -332,57 +334,6 @@ def get_output(data, lib):
332334
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
333335

334336

335-
# Tensorized intrinsic description and VNNI-specific implementation.
336-
# Equivalent to the ones in topi/x86/tensor_intrin.py
337-
338-
339-
@T.prim_func
340-
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
341-
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
342-
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
343-
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
344-
345-
with T.block("root"):
346-
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
347-
T.writes(C[0:16])
348-
for i in T.serial(0, 16):
349-
with T.init():
350-
C[i] = T.int32(0)
351-
for k in T.serial(0, 4):
352-
with T.block("update"):
353-
vi, vk = T.axis.remap("SR", [i, k])
354-
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
355-
356-
357-
@T.prim_func
358-
def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
359-
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
360-
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
361-
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
362-
363-
with T.block("root"):
364-
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
365-
T.writes(C[0:16])
366-
367-
A_u8x4 = A.vload([0], "uint8x4")
368-
A_i32 = T.reinterpret(A_u8x4, dtype="int32")
369-
370-
B_i8x64 = B.vload([0, 0], dtype="int8x64")
371-
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
372-
373-
C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
374-
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
375-
T.uint32(0),
376-
T.int32x16(0),
377-
T.broadcast(A_i32, 16),
378-
B_i32x16,
379-
dtype="int32x16",
380-
)
381-
382-
383-
VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
384-
385-
386337
def schedule_dense(dense_block, M, do_tune, sch):
387338
"""
388339
Manually schedule a dense block, created from TE compute op via CreatePrimFunc,
@@ -550,10 +501,6 @@ def schedule_fn(task, sch):
550501

551502
@pytest.mark.skip("Requires cascadelake")
552503
def test_tune_relay_manual_tir_vnni():
553-
# Register a pair of an intrinsic description for 16x4 dot product, and its
554-
# VNNI-specific implementation.
555-
tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)
556-
557504
manual_tir_common(do_tune=False)
558505

559506
"""

0 commit comments

Comments (0)