|
42 | 42 | from tvm.target.target import Target |
43 | 43 | from tvm.tir.schedule import BlockRV, Schedule |
44 | 44 | from tvm.tir.schedule.trace import Trace |
| 45 | +from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN |
| 46 | + |
45 | 47 |
|
46 | 48 | logging.basicConfig() |
47 | 49 | logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) |
@@ -332,57 +334,6 @@ def get_output(data, lib): |
332 | 334 | assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) |
333 | 335 |
|
334 | 336 |
|
335 | | -# Tensorized intrinsic description and VNNI-specific implementation. |
336 | | -# Equivalent to the ones in topi/x86/tensor_intrin.py |
337 | | - |
338 | | - |
339 | | -@T.prim_func |
340 | | -def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: |
341 | | - A = T.match_buffer(a, (4,), "uint8", offset_factor=1) |
342 | | - B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) |
343 | | - C = T.match_buffer(c, (16,), "int32", offset_factor=1) |
344 | | - |
345 | | - with T.block("root"): |
346 | | - T.reads(C[0:16], A[0:4], B[0:16, 0:4]) |
347 | | - T.writes(C[0:16]) |
348 | | - for i in T.serial(0, 16): |
349 | | - with T.init(): |
350 | | - C[i] = T.int32(0) |
351 | | - for k in T.serial(0, 4): |
352 | | - with T.block("update"): |
353 | | - vi, vk = T.axis.remap("SR", [i, k]) |
354 | | - C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") |
355 | | - |
356 | | - |
357 | | -@T.prim_func |
358 | | -def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None: |
359 | | - A = T.match_buffer(a, (4,), "uint8", offset_factor=1) |
360 | | - B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) |
361 | | - C = T.match_buffer(c, (16,), "int32", offset_factor=1) |
362 | | - |
363 | | - with T.block("root"): |
364 | | - T.reads(C[0:16], A[0:4], B[0:16, 0:4]) |
365 | | - T.writes(C[0:16]) |
366 | | - |
367 | | - A_u8x4 = A.vload([0], "uint8x4") |
368 | | - A_i32 = T.reinterpret(A_u8x4, dtype="int32") |
369 | | - |
370 | | - B_i8x64 = B.vload([0, 0], dtype="int8x64") |
371 | | - B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16") |
372 | | - |
373 | | - C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update += |
374 | | - T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), |
375 | | - T.uint32(0), |
376 | | - T.int32x16(0), |
377 | | - T.broadcast(A_i32, 16), |
378 | | - B_i32x16, |
379 | | - dtype="int32x16", |
380 | | - ) |
381 | | - |
382 | | - |
383 | | -VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake" |
384 | | - |
385 | | - |
386 | 337 | def schedule_dense(dense_block, M, do_tune, sch): |
387 | 338 | """ |
388 | 339 | Manually schedule a dense block, created from TE compute op via CreatePrimFunc, |
@@ -550,10 +501,6 @@ def schedule_fn(task, sch): |
550 | 501 |
|
551 | 502 | @pytest.mark.skip("Requires cascadelake") |
552 | 503 | def test_tune_relay_manual_tir_vnni(): |
553 | | - # Register a pair of an intrinsic description for 16x4 dot product, and its |
554 | | - # VNNI-specific implementation. |
555 | | - tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni) |
556 | | - |
557 | 504 | manual_tir_common(do_tune=False) |
558 | 505 |
|
559 | 506 | """ |
|
0 commit comments