@@ -132,6 +132,199 @@ def transformed_simple_compute(
132132 C [tx , 15 ] = B [1 , tx , 0 ] + T .float32 (1 )
133133
134134
135+ @T .prim_func
136+ def three_stage_compute (A : T .Buffer [(16 , 16 ), "float32" ], D : T .Buffer [(16 , 16 ), "float32" ]):
137+ for tx in T .thread_binding (0 , 16 , thread = "threadIdx.x" ):
138+ for i in T .serial (
139+ 0 ,
140+ 16 ,
141+ annotations = {
142+ "software_pipeline_stage" : [0 , 1 , 2 ],
143+ "software_pipeline_order" : [0 , 1 , 2 ],
144+ },
145+ ):
146+ with T .block ():
147+ T .reads (A [tx , i ])
148+ T .writes (D [tx , i ])
149+ B = T .alloc_buffer ((16 , 1 ), dtype = "float32" , scope = "shared" )
150+ C = T .alloc_buffer ((16 , 1 ), dtype = "float32" , scope = "shared" )
151+ with T .block ():
152+ T .reads (A [tx , i ])
153+ T .writes (B [tx , 0 ])
154+ B [tx , 0 ] = A [tx , i ] * T .float32 (2 )
155+ with T .block ():
156+ T .reads (B [tx , 0 ])
157+ T .writes (C [tx , 0 ])
158+ C [tx , 0 ] = A [tx , 0 ] + T .float32 (2 )
159+ with T .block ():
160+ T .reads (C [tx , 0 ])
161+ T .writes (D [tx , i ])
162+ D [tx , i ] = C [tx , 0 ] + T .float32 (1 )
163+
164+
165+ @T .prim_func
166+ def transformed_three_stage_compute (
167+ A : T .Buffer [(16 , 16 ), "float32" ], D : T .Buffer [(16 , 16 ), "float32" ]
168+ ) -> None :
169+ for tx in T .thread_binding (16 , thread = "threadIdx.x" ):
170+ with T .block ():
171+ T .reads (A [tx , 0 :16 ])
172+ T .writes (D [tx , 0 :16 ])
173+ B = T .alloc_buffer ([2 , 16 , 1 ], dtype = "float32" , scope = "shared" )
174+ C = T .alloc_buffer ([2 , 16 , 1 ], dtype = "float32" , scope = "shared" )
175+ with T .block ():
176+ T .reads (A [tx , 0 :2 ], B [0 :2 , tx , 0 ])
177+ T .writes (B [0 :2 , tx , 0 ], C [0 :2 , tx , 0 ])
178+ for i in T .unroll (2 ):
179+ with T .block ():
180+ T .reads (A [tx , i ])
181+ T .writes (B [0 :2 , tx , 0 ])
182+ B [i , tx , 0 ] = A [tx , i ] * T .float32 (2 )
183+ with T .block ():
184+ T .where (1 <= i )
185+ T .reads (B [0 :2 , tx , 0 ])
186+ T .writes (C [0 :2 , tx , 0 ])
187+ C [(i + 1 ) % 2 , tx , 0 ] = A [tx , 0 ] + T .float32 (2 )
188+ with T .block ():
189+ T .reads (A [tx , 2 :16 ], B [0 :2 , tx , 0 ], C [0 :2 , tx , 0 ])
190+ T .writes (B [0 :2 , tx , 0 ], C [0 :2 , tx , 0 ], D [tx , 0 :14 ])
191+ for i in T .serial (14 ):
192+ with T .block ():
193+ T .reads (A [tx , i + 2 ])
194+ T .writes (B [0 :2 , tx , 0 ])
195+ B [i % 2 , tx , 0 ] = A [tx , i + 2 ] * T .float32 (2 )
196+ with T .block ():
197+ T .reads (B [0 :2 , tx , 0 ])
198+ T .writes (C [0 :2 , tx , 0 ])
199+ C [(i + 1 ) % 2 , tx , 0 ] = A [tx , 0 ] + T .float32 (2 )
200+ with T .block ():
201+ T .reads (C [0 :2 , tx , 0 ])
202+ T .writes (D [tx , i ])
203+ D [tx , i ] = C [i % 2 , tx , 0 ] + T .float32 (1 )
204+ with T .block ():
205+ T .reads (B [0 :2 , tx , 0 ], C [0 :2 , tx , 0 ])
206+ T .writes (C [0 :2 , tx , 0 ], D [tx , 14 :16 ])
207+ for i in T .unroll (2 ):
208+ with T .block ():
209+ T .where (i < 1 )
210+ T .reads (B [0 :2 , tx , 0 ])
211+ T .writes (C [0 :2 , tx , 0 ])
212+ C [(i + 1 ) % 2 , tx , 0 ] = A [tx , 0 ] + T .float32 (2 )
213+ with T .block ():
214+ T .reads (C [0 :2 , tx , 0 ])
215+ T .writes (D [tx , i + 14 ])
216+ D [tx , i + 14 ] = C [i , tx , 0 ] + T .float32 (1 )
217+
218+
219+ @T .prim_func
220+ def dag_interleaving (
221+ A : T .Buffer [(16 , 16 ), "float32" ],
222+ B : T .Buffer [(16 , 16 ), "float32" ],
223+ C : T .Buffer [(16 , 16 ), "float32" ],
224+ ) -> None :
225+ for tx in T .thread_binding (0 , 16 , thread = "threadIdx.x" ):
226+ for i in T .serial (
227+ 0 ,
228+ 16 ,
229+ annotations = {
230+ "software_pipeline_stage" : [0 , 0 , 0 , 0 , 1 ],
231+ "software_pipeline_order" : [0 , 2 , 1 , 3 , 4 ],
232+ },
233+ ):
234+ with T .block ():
235+ T .reads (A [tx , i ])
236+ T .writes (C [tx , i ])
237+ AS = T .alloc_buffer ((16 , 1 ), dtype = "float32" , scope = "shared" )
238+ BS = T .alloc_buffer ((16 , 1 ), dtype = "float32" , scope = "shared" )
239+ AL = T .alloc_buffer ((1 , 1 ), dtype = "float32" , scope = "local" )
240+ BL = T .alloc_buffer ((1 , 1 ), dtype = "float32" , scope = "local" )
241+ with T .block ():
242+ T .reads (A [tx , i ])
243+ T .writes (AS [tx , 0 ])
244+ AS [tx , 0 ] = A [tx , i ] * T .float32 (2 )
245+ with T .block ():
246+ T .reads (AS [tx , 0 ])
247+ T .writes (AL [0 , 0 ])
248+ AL [0 , 0 ] = AS [tx , 0 ]
249+ with T .block ():
250+ T .reads (B [tx , i ])
251+ T .writes (BS [tx , 0 ])
252+ BS [tx , 0 ] = B [tx , i ] + T .float32 (2 )
253+ with T .block ():
254+ T .reads (BS [tx , 0 ])
255+ T .writes (BL [0 , 0 ])
256+ BL [0 , 0 ] = BS [tx , 0 ]
257+ with T .block ():
258+ T .reads (AL [0 , 0 ], BL [0 , 0 ])
259+ T .writes (C [tx , i ])
260+ C [tx , i ] = AL [0 , 0 ] * BL [0 , 0 ]
261+
262+
263+ @T .prim_func
264+ def transformed_dag_interleaving (
265+ A : T .Buffer [(16 , 16 ), "float32" ],
266+ B : T .Buffer [(16 , 16 ), "float32" ],
267+ C : T .Buffer [(16 , 16 ), "float32" ],
268+ ) -> None :
269+ for tx in T .thread_binding (16 , thread = "threadIdx.x" ):
270+ with T .block ():
271+ T .reads (A [tx , 0 :16 ], B [tx , 0 :16 ])
272+ T .writes (C [tx , 0 :16 ])
273+ AS = T .alloc_buffer ([16 , 1 ], dtype = "float32" , scope = "shared" )
274+ BS = T .alloc_buffer ([16 , 1 ], dtype = "float32" , scope = "shared" )
275+ AL = T .alloc_buffer ([2 , 1 , 1 ], dtype = "float32" , scope = "local" )
276+ BL = T .alloc_buffer ([2 , 1 , 1 ], dtype = "float32" , scope = "local" )
277+ with T .block ():
278+ T .reads (A [tx , 0 ], B [tx , 0 ], AS [tx , 0 ], BS [tx , 0 ])
279+ T .writes (AS [tx , 0 ], BS [tx , 0 ], AL [0 , 0 , 0 ], BL [0 , 0 , 0 ])
280+ with T .block ():
281+ T .reads (A [tx , 0 ])
282+ T .writes (AS [tx , 0 ])
283+ AS [tx , 0 ] = A [tx , 0 ] * T .float32 (2 )
284+ with T .block ():
285+ T .reads (B [tx , 0 ])
286+ T .writes (BS [tx , 0 ])
287+ BS [tx , 0 ] = B [tx , 0 ] + T .float32 (2 )
288+ with T .block ():
289+ T .reads (AS [tx , 0 ])
290+ T .writes (AL [0 , 0 , 0 ])
291+ AL [0 , 0 , 0 ] = AS [tx , 0 ]
292+ with T .block ():
293+ T .reads (BS [tx , 0 ])
294+ T .writes (BL [0 , 0 , 0 ])
295+ BL [0 , 0 , 0 ] = BS [tx , 0 ]
296+ with T .block ():
297+ T .reads (
298+ A [tx , 1 :16 ], B [tx , 1 :16 ], AS [tx , 0 ], BS [tx , 0 ], AL [0 :2 , 0 , 0 ], BL [0 :2 , 0 , 0 ]
299+ )
300+ T .writes (AS [tx , 0 ], BS [tx , 0 ], AL [0 :2 , 0 , 0 ], BL [0 :2 , 0 , 0 ], C [tx , 0 :15 ])
301+ for i in T .serial (15 ):
302+ with T .block ():
303+ T .reads (A [tx , i + 1 ])
304+ T .writes (AS [tx , 0 ])
305+ AS [tx , 0 ] = A [tx , i + 1 ] * T .float32 (2 )
306+ with T .block ():
307+ T .reads (B [tx , i + 1 ])
308+ T .writes (BS [tx , 0 ])
309+ BS [tx , 0 ] = B [tx , i + 1 ] + T .float32 (2 )
310+ with T .block ():
311+ T .reads (AS [tx , 0 ])
312+ T .writes (AL [(i + 1 ) % 2 , 0 , 0 ])
313+ AL [(i + 1 ) % 2 , 0 , 0 ] = AS [tx , 0 ]
314+ with T .block ():
315+ T .reads (BS [tx , 0 ])
316+ T .writes (BL [(i + 1 ) % 2 , 0 , 0 ])
317+ BL [(i + 1 ) % 2 , 0 , 0 ] = BS [tx , 0 ]
318+ with T .block ():
319+ T .reads (AL [i % 2 , 0 , 0 ], BL [i % 2 , 0 , 0 ])
320+ T .writes (C [tx , i ])
321+ C [tx , i ] = AL [i % 2 , 0 , 0 ] * BL [i % 2 , 0 , 0 ]
322+ with T .block ():
323+ T .reads (AL [1 , 0 , 0 ], BL [1 , 0 , 0 ])
324+ T .writes (C [tx , 15 ])
325+ C [tx , 15 ] = AL [1 , 0 , 0 ] * BL [1 , 0 , 0 ]
326+
327+
135328@T .prim_func
136329def nested_pipeline_simple (
137330 A : T .Buffer [(16 , 16 , 16 ), "float32" ], C : T .Buffer [(16 , 16 , 16 ), "float32" ]
@@ -792,6 +985,14 @@ def test_trivial_pipeline():
792985 _check (trivial_pipeline , transformed_trivial_pipeline )
793986
794987
988+ def test_three_stage_compute ():
989+ _check (three_stage_compute , transformed_three_stage_compute )
990+
991+
992+ def test_dag_interleaving ():
993+ _check (dag_interleaving , transformed_dag_interleaving )
994+
995+
795996def test_nest_pipeline_simple ():
796997 _check (nested_pipeline_simple , transformed_nested_pipeline_simple )
797998
0 commit comments