Skip to content

Commit 9c40f2e

Browse files
authored
[TIR] Enhance software pipeline validation and fix predicate of epilogue (#11106)
* Fix pipeline validation * fix predicate * Update test_tir_transform_inject_software_pipeline.py * Update inject_software_pipeline.cc
1 parent 4015916 commit 9c40f2e

File tree

2 files changed

+266
-9
lines changed

2 files changed

+266
-9
lines changed

src/tir/transforms/inject_software_pipeline.cc

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,10 @@ class PipelineRewriter : public StmtExprMutator {
534534
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var);
535535
} else {
536536
// normalize loop range
537-
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min));
537+
PrimExpr delta = start - pipeline_loop_->min;
538+
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + delta);
539+
Var loop_iter = Downcast<Var>(new_loop_var);
540+
inbound = Substitute(inbound, Map<Var, PrimExpr>{{loop_iter, loop_iter + delta}});
538541
}
539542
new_block = Downcast<Block>(Substitute(new_block, subst_map));
540543
stmts.push_back(BlockRealize({}, inbound, new_block));
@@ -570,6 +573,40 @@ class PipelineRewriter : public StmtExprMutator {
570573
Array<Block> ordered_stmts_;
571574
};
572575

576+
/*!
577+
* \brief Build the dependency graph among a array of blocks.
578+
* \param[in] blocks The array of blocks.
579+
* \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the
580+
* destination.
581+
* \param[out] dep_dst2src Optional, a map to store dependency edges from the
582+
* destination to the source.
583+
*/
584+
void BuildDependencyGraph(
585+
const Array<Block>& blocks,
586+
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst,
587+
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) {
588+
std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
589+
590+
for (const Block& block : blocks) {
591+
for (const BufferRegion& read : block->reads) {
592+
auto it = buffer_writers.find(read->buffer->data);
593+
if (it != buffer_writers.end()) {
594+
for (const Block& writer : it->second) {
595+
if (dep_src2dst != nullptr) {
596+
(*dep_src2dst)[writer].push_back(block);
597+
}
598+
if (dep_dst2src != nullptr) {
599+
(*dep_dst2src)[block].push_back(writer);
600+
}
601+
}
602+
}
603+
}
604+
for (const BufferRegion& write : block->writes) {
605+
buffer_writers[write->buffer->data].push_back(block);
606+
}
607+
}
608+
}
609+
573610
class PipelineInjector : private StmtExprMutator {
574611
public:
575612
static Stmt Inject(const PrimFunc& func) {
@@ -587,24 +624,43 @@ class PipelineInjector : private StmtExprMutator {
587624

588625
/*!
589626
* \brief Check the pipeline satisfies the following conditions:
590-
* 1) No conflicting order: The order of each statement should be unique.
591-
* 2) No reordering with the same stage: Statements in the same stage are not allowed to be
592-
* reordered.
627+
* 1. No conflicting order: The order of each statement should be unique.
628+
* 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for
629+
* dependency (e.g. read-after-write) from statement A to statement B, it requires:
630+
* case 1: stage(A) < stage(B)
631+
* case 2: stage(A) == stage(B) and order(A) < order(B)
593632
*/
594633
void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) {
595634
std::unordered_set<int> used_orders;
596635
std::unordered_map<int, int> stage_max_order;
636+
std::unordered_map<int, const Block*> order_to_block;
637+
std::unordered_map<const Block*, int> block_to_stage;
597638
for (const Block& block : original_order) {
598639
const auto& stmt_info = pipeline_info.at(block);
599-
int stage = stmt_info.stage;
600640
int order = stmt_info.order;
601641
CHECK(!used_orders.count(order))
602642
<< "ValueError: Two statements in the software pipeline cannot have the same order";
603643
used_orders.insert(order);
604-
CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order)
605-
<< "ValueError: Statements in the same stage of the software pipeline must have "
606-
"increasing order.";
607-
stage_max_order[stage] = order;
644+
}
645+
646+
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst;
647+
BuildDependencyGraph(original_order, &dep_src2dst, nullptr);
648+
649+
for (const auto& pair : dep_src2dst) {
650+
const Block& src = pair.first;
651+
const auto& src_info = pipeline_info.at(src);
652+
const Array<Block>& dsts = pair.second;
653+
for (const Block& dst : dsts) {
654+
const auto& dst_info = pipeline_info.at(dst);
655+
CHECK_LE(src_info.stage, dst_info.stage)
656+
<< "ValueError: statement " << dst << " in stage " << dst_info.stage
657+
<< " cannot depends on statement " << src << " in a later stage " << src_info.stage;
658+
if (src_info.stage == dst_info.stage) {
659+
CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer "
660+
"access dependency in the same stage of the "
661+
"software pipeline cannot be reordered";
662+
}
663+
}
608664
}
609665
}
610666

tests/python/unittest/test_tir_transform_inject_software_pipeline.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,199 @@ def transformed_simple_compute(
132132
C[tx, 15] = B[1, tx, 0] + T.float32(1)
133133

134134

135+
@T.prim_func
136+
def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
137+
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
138+
for i in T.serial(
139+
0,
140+
16,
141+
annotations={
142+
"software_pipeline_stage": [0, 1, 2],
143+
"software_pipeline_order": [0, 1, 2],
144+
},
145+
):
146+
with T.block():
147+
T.reads(A[tx, i])
148+
T.writes(D[tx, i])
149+
B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
150+
C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
151+
with T.block():
152+
T.reads(A[tx, i])
153+
T.writes(B[tx, 0])
154+
B[tx, 0] = A[tx, i] * T.float32(2)
155+
with T.block():
156+
T.reads(B[tx, 0])
157+
T.writes(C[tx, 0])
158+
C[tx, 0] = A[tx, 0] + T.float32(2)
159+
with T.block():
160+
T.reads(C[tx, 0])
161+
T.writes(D[tx, i])
162+
D[tx, i] = C[tx, 0] + T.float32(1)
163+
164+
165+
@T.prim_func
166+
def transformed_three_stage_compute(
167+
A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]
168+
) -> None:
169+
for tx in T.thread_binding(16, thread="threadIdx.x"):
170+
with T.block():
171+
T.reads(A[tx, 0:16])
172+
T.writes(D[tx, 0:16])
173+
B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
174+
C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
175+
with T.block():
176+
T.reads(A[tx, 0:2], B[0:2, tx, 0])
177+
T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
178+
for i in T.unroll(2):
179+
with T.block():
180+
T.reads(A[tx, i])
181+
T.writes(B[0:2, tx, 0])
182+
B[i, tx, 0] = A[tx, i] * T.float32(2)
183+
with T.block():
184+
T.where(1 <= i)
185+
T.reads(B[0:2, tx, 0])
186+
T.writes(C[0:2, tx, 0])
187+
C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2)
188+
with T.block():
189+
T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
190+
T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
191+
for i in T.serial(14):
192+
with T.block():
193+
T.reads(A[tx, i + 2])
194+
T.writes(B[0:2, tx, 0])
195+
B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
196+
with T.block():
197+
T.reads(B[0:2, tx, 0])
198+
T.writes(C[0:2, tx, 0])
199+
C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2)
200+
with T.block():
201+
T.reads(C[0:2, tx, 0])
202+
T.writes(D[tx, i])
203+
D[tx, i] = C[i % 2, tx, 0] + T.float32(1)
204+
with T.block():
205+
T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
206+
T.writes(C[0:2, tx, 0], D[tx, 14:16])
207+
for i in T.unroll(2):
208+
with T.block():
209+
T.where(i < 1)
210+
T.reads(B[0:2, tx, 0])
211+
T.writes(C[0:2, tx, 0])
212+
C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2)
213+
with T.block():
214+
T.reads(C[0:2, tx, 0])
215+
T.writes(D[tx, i + 14])
216+
D[tx, i + 14] = C[i, tx, 0] + T.float32(1)
217+
218+
219+
@T.prim_func
220+
def dag_interleaving(
221+
A: T.Buffer[(16, 16), "float32"],
222+
B: T.Buffer[(16, 16), "float32"],
223+
C: T.Buffer[(16, 16), "float32"],
224+
) -> None:
225+
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
226+
for i in T.serial(
227+
0,
228+
16,
229+
annotations={
230+
"software_pipeline_stage": [0, 0, 0, 0, 1],
231+
"software_pipeline_order": [0, 2, 1, 3, 4],
232+
},
233+
):
234+
with T.block():
235+
T.reads(A[tx, i])
236+
T.writes(C[tx, i])
237+
AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
238+
BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
239+
AL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
240+
BL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
241+
with T.block():
242+
T.reads(A[tx, i])
243+
T.writes(AS[tx, 0])
244+
AS[tx, 0] = A[tx, i] * T.float32(2)
245+
with T.block():
246+
T.reads(AS[tx, 0])
247+
T.writes(AL[0, 0])
248+
AL[0, 0] = AS[tx, 0]
249+
with T.block():
250+
T.reads(B[tx, i])
251+
T.writes(BS[tx, 0])
252+
BS[tx, 0] = B[tx, i] + T.float32(2)
253+
with T.block():
254+
T.reads(BS[tx, 0])
255+
T.writes(BL[0, 0])
256+
BL[0, 0] = BS[tx, 0]
257+
with T.block():
258+
T.reads(AL[0, 0], BL[0, 0])
259+
T.writes(C[tx, i])
260+
C[tx, i] = AL[0, 0] * BL[0, 0]
261+
262+
263+
@T.prim_func
264+
def transformed_dag_interleaving(
265+
A: T.Buffer[(16, 16), "float32"],
266+
B: T.Buffer[(16, 16), "float32"],
267+
C: T.Buffer[(16, 16), "float32"],
268+
) -> None:
269+
for tx in T.thread_binding(16, thread="threadIdx.x"):
270+
with T.block():
271+
T.reads(A[tx, 0:16], B[tx, 0:16])
272+
T.writes(C[tx, 0:16])
273+
AS = T.alloc_buffer([16, 1], dtype="float32", scope="shared")
274+
BS = T.alloc_buffer([16, 1], dtype="float32", scope="shared")
275+
AL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local")
276+
BL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local")
277+
with T.block():
278+
T.reads(A[tx, 0], B[tx, 0], AS[tx, 0], BS[tx, 0])
279+
T.writes(AS[tx, 0], BS[tx, 0], AL[0, 0, 0], BL[0, 0, 0])
280+
with T.block():
281+
T.reads(A[tx, 0])
282+
T.writes(AS[tx, 0])
283+
AS[tx, 0] = A[tx, 0] * T.float32(2)
284+
with T.block():
285+
T.reads(B[tx, 0])
286+
T.writes(BS[tx, 0])
287+
BS[tx, 0] = B[tx, 0] + T.float32(2)
288+
with T.block():
289+
T.reads(AS[tx, 0])
290+
T.writes(AL[0, 0, 0])
291+
AL[0, 0, 0] = AS[tx, 0]
292+
with T.block():
293+
T.reads(BS[tx, 0])
294+
T.writes(BL[0, 0, 0])
295+
BL[0, 0, 0] = BS[tx, 0]
296+
with T.block():
297+
T.reads(
298+
A[tx, 1:16], B[tx, 1:16], AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0]
299+
)
300+
T.writes(AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0], C[tx, 0:15])
301+
for i in T.serial(15):
302+
with T.block():
303+
T.reads(A[tx, i + 1])
304+
T.writes(AS[tx, 0])
305+
AS[tx, 0] = A[tx, i + 1] * T.float32(2)
306+
with T.block():
307+
T.reads(B[tx, i + 1])
308+
T.writes(BS[tx, 0])
309+
BS[tx, 0] = B[tx, i + 1] + T.float32(2)
310+
with T.block():
311+
T.reads(AS[tx, 0])
312+
T.writes(AL[(i + 1) % 2, 0, 0])
313+
AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
314+
with T.block():
315+
T.reads(BS[tx, 0])
316+
T.writes(BL[(i + 1) % 2, 0, 0])
317+
BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
318+
with T.block():
319+
T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
320+
T.writes(C[tx, i])
321+
C[tx, i] = AL[i % 2, 0, 0] * BL[i % 2, 0, 0]
322+
with T.block():
323+
T.reads(AL[1, 0, 0], BL[1, 0, 0])
324+
T.writes(C[tx, 15])
325+
C[tx, 15] = AL[1, 0, 0] * BL[1, 0, 0]
326+
327+
135328
@T.prim_func
136329
def nested_pipeline_simple(
137330
A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
@@ -792,6 +985,14 @@ def test_trivial_pipeline():
792985
_check(trivial_pipeline, transformed_trivial_pipeline)
793986

794987

988+
def test_three_stage_compute():
989+
_check(three_stage_compute, transformed_three_stage_compute)
990+
991+
992+
def test_dag_interleaving():
993+
_check(dag_interleaving, transformed_dag_interleaving)
994+
995+
795996
def test_nest_pipeline_simple():
796997
_check(nested_pipeline_simple, transformed_nested_pipeline_simple)
797998

0 commit comments

Comments
 (0)