Skip to content

Commit befdc4e

Browse files
authored
[Fix][TIR] LowerCrossThreadReduction with write-back predicate (#14199)
Prior to this PR, the cross-thread reduction lowering pass did not add a store predicate to the write-back block. The rationale was that, for any given write-back buffer position, all values stored (by all threads) in the write-back block are the same. Since every thread writes the same value, we assumed that omitting the write-back block predicate was fine, because the result would not be wrong in any way. However, we recently noticed that some GPU backend compilers flag this behavior (multiple threads writing to the same position) as a race condition and therefore throw a compilation error. The compiler does not take into account the fact that all values being stored are identical, and insists on complaining. This means we still need the write-back block predicate to make things work, and this PR makes that change. I have run integration tests locally to confirm that the generated kernels are correct and produce numerically right results.
1 parent baedf7f commit befdc4e

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/tir/transforms/lower_cross_thread_reduction.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,15 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, //
407407
BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices));
408408
wb_regions.push_back(BufferRegion(wb_buffers[i], region));
409409
}
410+
PrimExpr wb_predicate = const_true();
411+
for (const ForNode* loop : reduction_loops) {
412+
if (loop->thread_binding.defined()) {
413+
wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
414+
}
415+
}
410416
stmts.push_back(BlockRealize(
411417
/*iter_values=*/std::move(bindings),
412-
/*predicate=*/const_true(),
418+
/*predicate=*/wb_predicate,
413419
/*block=*/
414420
Block(/*iter_vars=*/std::move(iter_vars),
415421
/*reads=*/std::move(ct_buffer_regions),

tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None:
9393
)
9494
with T.block("B_write_back"):
9595
vi = T.axis.S(128, i)
96+
T.where(ki == 0)
9697
T.reads([reduce_temp0[0]])
9798
T.writes([B[vi]])
9899
B[vi] = reduce_temp0[0]
@@ -136,6 +137,7 @@ def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None:
136137
)
137138
with T.block("B_write_back"):
138139
vi = T.axis.spatial(128, i)
140+
T.where(k == 0)
139141
T.reads([reduce_temp0[0]])
140142
T.writes([B[vi]])
141143
B[vi] = reduce_temp0[0]
@@ -183,6 +185,7 @@ def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
183185
)
184186
with T.block("B_write_back"):
185187
vi = T.axis.spatial(128, i)
188+
T.where(ko == 0 and ki == 0)
186189
T.reads([reduce_temp0[0]])
187190
T.writes([B[vi]])
188191
B[vi] = reduce_temp0[0]
@@ -264,6 +267,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No
264267
)
265268
with T.block("B_write_back"):
266269
vi = T.axis.spatial(16, i)
270+
T.where(k0o == 0)
267271
T.reads([reduce_temp0[0]])
268272
T.writes([B[vi]])
269273
B[vi] = reduce_temp0[0]
@@ -326,6 +330,7 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
326330
)
327331
with T.block("B_write_back"):
328332
vi = T.axis.spatial(128, i)
333+
T.where(ki == 0)
329334
T.reads([reduce_temp0[0]])
330335
T.writes([B[vi]])
331336
B[vi] = reduce_temp0[0]
@@ -428,6 +433,7 @@ def lowered_single_reduction_loop_with_block_predicate(
428433
)
429434
with T.block("T_softmax_maxelem_write_back"):
430435
i0_2 = T.axis.spatial(256, i0 + ax0)
436+
T.where(ax1_1 == 0)
431437
T.reads(cross_thread_0[0])
432438
T.writes(T_softmax_maxelem_shared[i0_2])
433439
T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
@@ -467,6 +473,7 @@ def lowered_single_reduction_loop_with_block_predicate(
467473
)
468474
with T.block("T_softmax_expsum_write_back"):
469475
i0_4 = T.axis.spatial(256, i0 + ax0)
476+
T.where(ax1_1 == 0)
470477
T.reads(cross_thread_1[0])
471478
T.writes(T_softmax_expsum_shared[i0_4])
472479
T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
@@ -636,6 +643,7 @@ def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
636643
)
637644
with T.block("B_write_back"):
638645
vi = T.axis.spatial(128, i)
646+
T.where(k == 0)
639647
T.reads([reduce_temp0[0]])
640648
T.writes([B[vi]])
641649
B[vi] = reduce_temp0[0]
@@ -676,6 +684,7 @@ def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
676684
with T.block("B_write_back"):
677685
T.reads([reduce_temp0[0]])
678686
T.writes([B[()]])
687+
T.where(k == 0)
679688
B[()] = reduce_temp0[0]
680689

681690

@@ -865,6 +874,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
865874
)
866875
with T.block("T_softmax_maxelem_write_back"):
867876
i0_2 = T.axis.spatial(256, i0)
877+
T.where(ax0_1 == 0)
868878
T.reads([reduce_temp0[0]])
869879
T.writes([T_softmax_maxelem_shared[i0_2]])
870880
T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
@@ -907,6 +917,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
907917
)
908918
with T.block("T_softmax_expsum_write_back"):
909919
i0_4 = T.axis.spatial(256, i0)
920+
T.where(ax0_1 == 0)
910921
T.reads([reduce_temp1[0]])
911922
T.writes([T_softmax_expsum_shared[i0_4]])
912923
T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
@@ -1018,6 +1029,7 @@ def lowered_argmax_split(
10181029
)
10191030
with T.block("argmax_write_back"):
10201031
i = T.axis.spatial(128, i0)
1032+
T.where(i1_1 == 0)
10211033
T.reads(cross_thread_argmax_v0[0], cross_thread_argmax_v1[0])
10221034
T.writes(argmax_v0[i], argmax_v1[i])
10231035
argmax_v0[i] = cross_thread_argmax_v0[0]
@@ -1109,6 +1121,7 @@ def lowered_argmin_split_init_update_reordered(
11091121
)
11101122
with T.block("argmin_write_back"):
11111123
i = T.axis.spatial(128, i0)
1124+
T.where(i1_1 == 0)
11121125
T.reads(cross_thread_argmin_v0[0], cross_thread_argmin_v1[0])
11131126
T.writes(argmin_v0[i], argmin_v1[i])
11141127
argmin_v0[i] = cross_thread_argmin_v0[0]
@@ -1227,6 +1240,7 @@ def lowered_layer_norm_tuple_sum(
12271240
)
12281241
with T.block("data_red_temp_write_back"):
12291242
ax0 = T.axis.spatial(128, i0_fused)
1243+
T.where(i1_1 == 0)
12301244
T.reads(cross_thread_data_red_temp_v0[0], cross_thread_data_red_temp_v1[0])
12311245
T.writes(data_red_temp_v0[ax0], data_red_temp_v1[ax0])
12321246
data_red_temp_v0[ax0] = cross_thread_data_red_temp_v0[0]

0 commit comments

Comments
 (0)