Skip to content

Commit 625cd27

Browse files
committed
fixed offset factor
1 parent 69e72b6 commit 625cd27

File tree

2 files changed

Lines changed: 15 additions & 5 deletions

2 files changed

Lines changed: 15 additions & 5 deletions

python/tvm/tir/tensor_intrin/arm_cpu.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,9 @@
2020

2121
@T.prim_func
2222
def dot_product_4x4_i8i8i32_desc(
23-
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
23+
A: T.Buffer((4,), "int8", offset_factor=1),
24+
B: T.Buffer((4, 4), "int8", offset_factor=1),
25+
C: T.Buffer((4,), "int32", offset_factor=1),
2426
) -> None:
2527
with T.block("root"):
2628
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
@@ -36,7 +38,9 @@ def dot_product_4x4_i8i8i32_desc(
3638

3739
@T.prim_func
3840
def dot_product_4x4_i8i8i32_neon(
39-
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
41+
A: T.Buffer((4,), "int8", offset_factor=1),
42+
B: T.Buffer((4, 4), "int8", offset_factor=1),
43+
C: T.Buffer((4,), "int32", offset_factor=1),
4044
) -> None:
4145
with T.block("root"):
4246
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
@@ -92,7 +96,9 @@ def dot_product_4x4_i8i8i32_neon(
9296

9397
@T.prim_func
9498
def dot_product_4x4_i8i8i32_sdot(
95-
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
99+
A: T.Buffer((4,), "int8", offset_factor=1),
100+
B: T.Buffer((4, 4), "int8", offset_factor=1),
101+
C: T.Buffer((4,), "int32", offset_factor=1),
96102
) -> None:
97103
with T.block("root"):
98104
T.reads(C[0:4], A[0:4], B[0:4, 0:4])

python/tvm/tir/tensor_intrin/x86.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,9 @@
2424

2525
@T.prim_func
2626
def dot_product_16x4_u8i8i32_desc(
27-
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
27+
A: T.Buffer((4,), "uint8", offset_factor=1),
28+
B: T.Buffer((16, 4), "int8", offset_factor=1),
29+
C: T.Buffer((16,), "int32", offset_factor=1),
2830
) -> None:
2931
with T.block("root"):
3032
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
@@ -40,7 +42,9 @@ def dot_product_16x4_u8i8i32_desc(
4042

4143
@T.prim_func
4244
def dot_product_16x4_u8i8i32_vnni(
43-
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
45+
A: T.Buffer((4,), "uint8", offset_factor=1),
46+
B: T.Buffer((16, 4), "int8", offset_factor=1),
47+
C: T.Buffer((16,), "int32", offset_factor=1),
4448
) -> None:
4549
with T.block("root"):
4650
T.reads(C[0:16], A[0:4], B[0:16, 0:4])

0 commit comments

Comments (0)