1818""" Hexagon testing infrastructure """
1919
2020import tvm
21+ from tvm import te
2122import numpy
2223
2324
def ceildiv(o, d):
    """Return the ceiling of ``o / d`` as ``floordiv(o + d - 1, d)``.

    Parameters
    ----------
    o : numerator; must be non-negative.
    d : divisor; must be strictly positive (a zero divisor would make the
        floor division below undefined).

    Returns
    -------
    The result of ``tvm.tir.floordiv(o + d - 1, d)``.

    Raises
    ------
    AssertionError
        If ``o`` is negative or ``d`` is not strictly positive.
    """
    assert o >= 0
    # Strictly positive: d == 0 previously slipped past the check and only
    # failed later inside the division itself.
    assert d > 0
    return tvm.tir.floordiv(o + d - 1, d)
2629
2730
def get_packed_activation_layout(shape_nhwc, block_shape, packed_C=True):
    """Build the blocked layout shape for a 4-d NHWC activation shape.

    H and W are always replaced by their block counts (rounded up with
    ``ceildiv``).  When ``packed_C`` is true, C is blocked the same way and
    the three block extents are appended; otherwise the H and W block extents
    are appended followed by the unsplit channel extent.

    Parameters
    ----------
    shape_nhwc : sequence of length 4.
    block_shape : sequence of three block extents, unpacked as (h, w, c).
    packed_C : whether the channel axis is blocked as well.

    Returns
    -------
    list
        7 entries when ``packed_C`` is true, 6 entries otherwise.
    """
    assert len(shape_nhwc) == 4
    block_h, block_w, block_c = block_shape
    layout = [
        shape_nhwc[0],
        ceildiv(shape_nhwc[1], block_h),
        ceildiv(shape_nhwc[2], block_w),
    ]
    if not packed_C:
        # Channel axis stays whole and goes last.
        return layout + [block_h, block_w, shape_nhwc[3]]
    layout.append(ceildiv(shape_nhwc[3], block_c))
    layout.extend(block_shape)
    return layout
40-
41-
31+ # defines inner block shape: 8h8w32c
def get_block_shape():
    """Return the inner activation block extents as a ``(8, 8, 32)`` tuple.

    Callers in this file unpack the result as (height, width, channel).
    """
    block_h, block_w, block_c = 8, 8, 32
    return block_h, block_w, block_c
4434
4535
# defines inner filter block shape: 8i32o4i
def get_filter_block_shape():
    """Return the inner filter block extents as a ``(8, 32, 4)`` tuple.

    Callers in this file unpack the result as
    ``(filter_Cio, filter_Ki, filter_Cii)``.
    """
    inner_shape = (8, 32, 4)
    return inner_shape
4839
4940
50- def get_packed_filter_layout (out_channel , in_channel , kernel_h , kernel_w ):
51- filter_Cio , filter_Ki , filter_Cii = get_filter_block_shape ()
# input: logical shape in nhwc layout
# output: physical packed shape in nhwc8h8w32c layout
def get_packed_shape(logical_shape_nhwc):
    """Convert a 4-d logical NHWC shape into its 7-d packed shape.

    The batch extent is kept as is; H, W and C are each replaced by the
    number of blocks needed to cover them (rounded up via ``ceildiv``), and
    the block extents from ``get_block_shape()`` are appended at the end.

    Parameters
    ----------
    logical_shape_nhwc : sequence of length 4.

    Returns
    -------
    list
        ``[n, h_blocks, w_blocks, c_blocks, block_h, block_w, block_c]``.
    """
    assert len(logical_shape_nhwc) == 4
    block_dims = get_block_shape()
    # Axes 1..3 (h, w, c) pair up with the three block extents in order.
    block_counts = [
        ceildiv(logical_shape_nhwc[axis], extent)
        for axis, extent in zip((1, 2, 3), block_dims)
    ]
    return [logical_shape_nhwc[0], *block_counts, *block_dims]
53+
54+
# input: physical packed shape in nhwc8h8w32c layout
# output: logical shape in nhwc layout
def get_logical_shape(physical_shape_nhwc8h8w32c):
    """Collapse a 7-d packed shape back into a 4-d logical NHWC shape.

    Each of H, W and C is the product of its block count (entries 1..3) and
    the matching block extent (entries 4..6); the batch entry is unchanged.

    Parameters
    ----------
    physical_shape_nhwc8h8w32c : sequence of length 7.

    Returns
    -------
    list
        ``[n, h, w, c]``.
    """
    assert len(physical_shape_nhwc8h8w32c) == 7
    (
        batch,
        h_blocks,
        w_blocks,
        c_blocks,
        block_h,
        block_w,
        block_c,
    ) = physical_shape_nhwc8h8w32c
    return [batch, h_blocks * block_h, w_blocks * block_w, c_blocks * block_c]
64+
65+
# input: logical shape in oihw layout
# output: physical packed shape in oihw8i32o4i layout
def get_packed_filter_shape(logical_shape_oihw):
    """Convert a 4-d logical OIHW filter shape into its 7-d packed shape.

    The output-channel extent is divided (rounding up) by the ``Ki`` block
    extent and the input-channel extent by ``Cio * Cii``; both quotients are
    converted with ``int``.  The kernel height and width pass through
    unchanged, and the three filter block extents are appended at the end.

    Parameters
    ----------
    logical_shape_oihw : sequence of length 4.

    Returns
    -------
    list
        ``[o_blocks, i_blocks, kernel_h, kernel_w, Cio, Ki, Cii]``.
    """
    assert len(logical_shape_oihw) == 4
    block_dims = get_filter_block_shape()
    block_cio, block_ki, block_cii = block_dims
    # Input channels are blocked by the product of the two input-channel
    # block extents.
    block_ci = block_cio * block_cii
    return [
        int(ceildiv(logical_shape_oihw[0], block_ki)),
        int(ceildiv(logical_shape_oihw[1], block_ci)),
        logical_shape_oihw[2],
        logical_shape_oihw[3],
        *block_dims,
    ]
6280
6381
6482def build_and_run (inputs , func , target , target_host , * args , ** kwargs ):
@@ -95,26 +113,10 @@ def get_conv2d_nhwc_shape(shape_nhwc, kernel_size, strides, padding, dilation, o
95113 )
96114
97115
98- def verify_conv2d (output , ref_output , dtype ):
99- # nhwc8h8w32c
100- if len (output .shape ) == 7 :
101- # nhwc8h8w32c -> nhwc
102- output = output .transpose (0 , 1 , 4 , 2 , 5 , 3 , 6 ).reshape (
103- output .shape [0 ],
104- output .shape [1 ] * output .shape [4 ],
105- output .shape [2 ] * output .shape [5 ],
106- output .shape [3 ] * output .shape [6 ],
107- )
108-
109- # nhwhwc
110- else :
111- # nhwhwc -> nhwc
112- output = output .transpose (0 , 1 , 3 , 2 , 4 , 5 ).reshape (
113- output .shape [0 ],
114- output .shape [1 ] * output .shape [3 ],
115- output .shape [2 ] * output .shape [4 ],
116- output .shape [5 ],
117- )
116+ def conv2d_verify (output , ref_output , dtype ):
117+ # nhwc8h8w32c -> nhwc
118+ logical_output_shape = get_logical_shape (output .shape )
119+ output = output .transpose (0 , 1 , 4 , 2 , 5 , 3 , 6 ).reshape (logical_output_shape )
118120
119121 # slice output to match ref_output shape
120122 # e.g. 8x8 spatial 3x3 filter = 6x6 ref output
@@ -131,3 +133,64 @@ def verify_conv2d(output, ref_output, dtype):
131133 elif dtype == "float32" :
132134 tol = {"rtol" : 1e-4 , "atol" : 2e-4 }
133135 tvm .testing .assert_allclose (output , ref_output , ** tol )
136+
137+
def conv2d_compute(X, filt, pad, stride, dilation):
    """Build the packed output shape and compute function for a conv2d.

    Parameters
    ----------
    X : input tensor; its ``shape`` must have 7 entries (it is passed to
        ``get_logical_shape``) and it is indexed below with 7 indices.
    filt : filter tensor; indexed below with 7 indices, and its ``shape``
        entries 0, 2, 3 and 5 are read here.
    pad, stride, dilation : forwarded to ``get_conv2d_nhwc_shape`` to derive
        the logical output shape.  ``stride[0]`` / ``stride[1]`` are also used
        in the index arithmetic of the compute function.

    Returns
    -------
    (output_shape, compute)
        ``output_shape`` is the 7-entry packed shape from
        ``get_packed_shape``; ``compute`` is a function of 7 output indices
        returning a ``te.sum`` reduction over ``rh``, ``rw`` and ``rc``.

    NOTE(review): ``pad`` and ``dilation`` only influence the output shape;
    the input indexing below uses neither.  Presumably X is expected to be
    pre-padded and dilation to be 1 — confirm against callers.
    """
    block_shape = get_block_shape()
    block_H, block_W, block_C = block_shape
    # NOTE(review): filter_Ki is not read anywhere below.
    filter_Cio, filter_Ki, filter_Cii = get_filter_block_shape()
    # Combined inner input-channel block extent of the filter.
    filter_Ci = filter_Cio * filter_Cii

    shape_filter = filt.shape
    # Entries 2 and 3 of the filter shape are the kernel height and width.
    kernel_size = tuple(shape_filter[2:4])
    # Outer output-channel count (entry 0) times the inner output-channel
    # block extent (entry 5).
    out_channels = shape_filter[0] * shape_filter[5]

    logical_input_shape = get_logical_shape(X.shape)
    logical_output_shape = get_conv2d_nhwc_shape(
        logical_input_shape,
        kernel_size,
        stride,
        pad,
        dilation,
        out_channels,
    )

    output_shape = get_packed_shape(logical_output_shape)
    # NOTE(review): these seven names are not read afterwards; the parameters
    # of the nested compute() below shadow them.
    n, ho, wo, ko, hi, wi, ki = output_shape
    # Reduction axes: kernel height, kernel width, logical input channels.
    rh = te.reduce_axis((0, kernel_size[0]), name="rh")
    rw = te.reduce_axis((0, kernel_size[1]), name="rw")
    rc = te.reduce_axis((0, logical_input_shape[3]), name="rc")

    def compute(n, ho, wo, ko, hi, wi, ki):
        # Logical output row, then the input row it reads for kernel row rh,
        # split back into (block index, offset within block).
        h = ho * block_H + hi
        h_contig = h * stride[0] + rh
        h_block_id = h_contig // block_H
        h_block_offset = h_contig % block_H

        # Same decomposition for the width axis.
        w = wo * block_W + wi
        w_contig = w * stride[1] + rw
        w_block_id = w_contig // block_W
        w_block_offset = w_contig % block_W

        # Input-channel reduction index split per the activation block size.
        c_block_id = rc // block_C
        c_block_offset = rc % block_C

        # Same reduction index split per the filter's input-channel blocking:
        # outer block, middle index, innermost index.
        rco = rc // filter_Ci
        rcio = (rc % filter_Ci) // filter_Cii
        rcii = rc % filter_Cii

        return te.sum(
            X[
                n,
                h_block_id,
                w_block_id,
                c_block_id,
                h_block_offset,
                w_block_offset,
                c_block_offset,
            ]
            * filt[ko, rco, rh, rw, rcio, ki, rcii],
            axis=[rh, rw, rc],
        )

    return output_shape, compute
0 commit comments