[Bugfix] Disable SingleEnvThreadVerifier #16361

Conversation
Oh, interesting. I initially added the check after running into this error previously. Can we add a test case that includes multiple `T.launch_thread` statements using the same thread tag, something like the following?

```python
@T.prim_func
def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
    tx_A = T.launch_thread("threadIdx.x", 16)
    tx_B = T.launch_thread("threadIdx.x", 16)
    B[tx_B] = A[tx_A]

tvm.build(func, target="cuda")
```
Hey, thanks for getting back to me @Lunderberg! The fundamental reason is that this pass is invoked during TIR schedule creation to validate that the TIR is well-formed, and at that stage env threads are allowed to duplicate. Below is an example of a cumulative-sum operator we use when enabling MegaBlock-style MoE Mixtral serving (see also PR: mlc-ai/mlc-llm#1529); it's the first step of that computation:

```python
@T.prim_func(private=True)
def main(var_A: T.handle, var_T_add: T.handle, seq_len: T.int64):
T.func_attr({"tir.noalias": T.bool(True)})
A = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32")
T_add = T.match_buffer(var_T_add, (seq_len * T.int64(8),), "int32")
# with T.block("root"):
T_expand_dims = T.alloc_buffer((T.int64(1), seq_len * T.int64(8)), "int32")
output_buf = T.alloc_buffer((T.int64(1), seq_len * T.int64(8)), "int32", align=8)
T_squeeze = T.alloc_buffer((seq_len * T.int64(8),), "int32")
for ax0, ax1 in T.grid(T.int64(1), seq_len * T.int64(8)):
with T.block("T_expand_dims"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax1])
T.writes(T_expand_dims[v_ax0, v_ax1])
T_expand_dims[v_ax0, v_ax1] = A[v_ax1]
with T.block("exclusive_scan"):
T.reads()
T.writes()
if seq_len * T.int64(8) == T.int64(0):
blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
if blockIdx_x < T.int64(1):
T.evaluate(0)
else:
with T.launch_thread("threadIdx.x", T.int64(1024)) as threadIdx_x:
blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (seq_len * T.int64(8) + T.int64(1023)) // T.int64(1024)))
blockIdx_y = T.launch_thread("blockIdx.y", T.int64(1))
if blockIdx_x * T.int64(1024) + threadIdx_x < seq_len * T.int64(8):
output_buf[(blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) % (seq_len * T.int64(8))] = T_expand_dims[(blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) % (seq_len * T.int64(8))]
for i in range(T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8)))))):
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (seq_len * T.int64(8) + (T.int64(1024) * T.shift_left(T.int64(2), i) - T.int64(1))) // (T.int64(1024) * T.shift_left(T.int64(2), i)))))
blockIdx_y = T.launch_thread("blockIdx.y", T.int64(1))
start = T.allocate([T.int64(1)], "int64", "local")
middle = T.allocate([T.int64(1)], "int64", "local")
end = T.allocate([T.int64(1)], "int64", "local")
start_1 = T.Buffer((1,), "int64", data=start, scope="local")
start_1[T.int64(0)] = T.shift_left(T.int64(2), i) * T.Cast("int64", blockIdx_x * 1024 + threadIdx_x)
if start_1[T.int64(0)] < seq_len * T.int64(8):
middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
middle_1[T.int64(0)] = start_1[T.int64(0)] + T.shift_left(T.int64(2), i) // T.int64(2)
end_1 = T.Buffer((1,), "int64", data=end, scope="local")
end_1[T.int64(0)] = T.min(start_1[T.int64(0)] + T.shift_left(T.int64(2), i), seq_len * T.int64(8))
if middle_1[T.int64(0)] < seq_len * T.int64(8):
output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] + output_buf[(blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))]
with T.launch_thread("blockIdx.x", T.int64(1)) as blockIdx_x:
if blockIdx_x < T.int64(1):
output_buf[((blockIdx_x + T.int64(1)) * (seq_len * T.int64(8)) - T.int64(1)) // (seq_len * T.int64(8)), ((blockIdx_x + T.int64(1)) * (seq_len * T.int64(8)) - T.int64(1)) % (seq_len * T.int64(8))] = 0
for j in range(T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8)))))):
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (seq_len * T.int64(8) + (T.int64(1024) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)) - T.int64(1))) // (T.int64(1024) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1))))))
blockIdx_y = T.launch_thread("blockIdx.y", T.int64(1))
start = T.allocate([T.int64(1)], "int64", "local")
middle = T.allocate([T.int64(1)], "int64", "local")
end = T.allocate([T.int64(1)], "int64", "local")
end_1 = T.allocate([T.int64(1)], "int32", "local")
start_1 = T.Buffer((1,), "int64", data=start, scope="local")
start_1[T.int64(0)] = T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)) * T.Cast("int64", blockIdx_x * 1024 + threadIdx_x)
if start_1[T.int64(0)] < seq_len * T.int64(8):
middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
middle_1[T.int64(0)] = start_1[T.int64(0)] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)) // T.int64(2)
end_2 = T.Buffer((1,), "int64", data=end, scope="local")
end_2[T.int64(0)] = T.min(start_1[T.int64(0)] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)), seq_len * T.int64(8))
if middle_1[T.int64(0)] < seq_len * T.int64(8):
end_3 = T.Buffer((1,), "int32", data=end_1, scope="local")
end_3[T.int64(0)] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))]
output_buf[(blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))]
output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] + end_3[T.int64(0)]
for ax0 in range(seq_len * T.int64(8)):
with T.block("T_squeeze"):
v_ax0 = T.axis.spatial(seq_len * T.int64(8), ax0)
T.reads(output_buf[T.int64(0), v_ax0])
T.writes(T_squeeze[v_ax0])
T_squeeze[v_ax0] = output_buf[T.int64(0), v_ax0]
for ax0 in range(seq_len * T.int64(8)):
with T.block("T_add"):
v_ax0 = T.axis.spatial(seq_len * T.int64(8), ax0)
T.reads(A[v_ax0], T_squeeze[v_ax0])
T.writes(T_add[v_ax0])
T_add[v_ax0] = A[v_ax0] + T_squeeze[v_ax0]
```

As you can see from the IR, the same environment threads (`threadIdx.x`, `blockIdx.x`, `blockIdx.y`) are launched multiple times within a single PrimFunc, which is what trips the verifier.
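For context, here is a minimal sketch (assuming the `main` prim_func above is in scope) of where the failure surfaces: simply constructing a `tir.Schedule` runs the well-formedness checks.

```python
import tvm

# `main` is the cumsum prim_func shown above. Creating the schedule runs the
# well-formedness checks; before this PR, SingleEnvThreadVerifier rejected the
# duplicated "threadIdx.x"/"blockIdx.x" IterVars here, even though no build
# had been attempted yet. With this PR the construction succeeds.
sch = tvm.tir.Schedule(main)
```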
Ah, I see. It looks like this originates from binding two nested loops to the same thread axis, which changes the iteration space and therefore the values that get computed:

```python
# Before
for i in range(16):
    for j in range(16):
        B[i, j] = A[i, j]

# First thread binding: still 16*16 = 256 total iterations. B is
# set at all indices.
for i in T.thread_binding('threadIdx.x', 16):
    for j in range(16):
        B[i, j] = A[i, j]

# Second thread binding: now only 16 iterations. B is only set
# for indices where i == j, along the diagonal.
for i in T.thread_binding('threadIdx.x', 16):
    for j in T.thread_binding('threadIdx.x', 16):
        B[i, j] = A[i, j]
```
On CUDA, when fetching data from global memory into shared memory, we use a technique called cooperative fetching, where the data is divided evenly and all threads cooperatively fetch their own portions. To implement this in TVM, we have to bind the loop of the shared-memory cache stage to the same thread axis that the surrounding compute loops are already bound to, which produces duplicated thread bindings.

It's the default schedule strategy used in MetaSchedule's multi-level tiling, and here's an example of this use case: https://github.com/apache/tvm/blob/main/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py#L71
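To make the pattern concrete, below is a minimal hand-written sketch, not the MetaSchedule rule itself; the buffer sizes, block names, and the 128-thread split are illustrative assumptions. It shows the schedule steps that end up binding `threadIdx.x` a second time for the cache stage.

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def copy_via_shared(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")):
    A_shared = T.alloc_buffer((1024,), "float32", scope="shared")
    for i in range(1024):
        with T.block("cache"):
            vi = T.axis.spatial(1024, i)
            A_shared[vi] = A[vi]
    for i in range(1024):
        with T.block("compute"):
            vi = T.axis.spatial(1024, i)
            B[vi] = A_shared[vi] + T.float32(1)


sch = tvm.tir.Schedule(copy_via_shared)

# Bind the compute loops to block/thread axes first.
compute = sch.get_block("compute")
(i,) = sch.get_loops(compute)
bx, tx = sch.split(i, factors=[None, 128])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")

# Move the shared-memory cache stage under the blockIdx.x loop; it now has a
# fresh loop over the 128 elements needed by this thread block.
cache = sch.get_block("cache")
sch.compute_at(cache, bx)
cache_loop = sch.get_loops(cache)[-1]

# Cooperative fetch: bind the cache loop to "threadIdx.x" as well, so all 128
# threads of the block load the shared slice together. The schedule now holds
# two IterVars named "threadIdx.x", which is what SingleEnvThreadVerifier used
# to reject.
sch.bind(cache_loop, "threadIdx.x")
```

At build time, the lowering passes (e.g. `tir.transform.UnifyThreadBinding`) merge bindings that share a thread tag into one physical axis, so the duplication only exists at the TIR scheduling level.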
Thank you for the details on cooperative fetching, and I agree that this check should be removed to make sure the use case is preserved, at least for the time being. I think as long as we have a unit test showing the functionality, this PR can be merged. My worry wasn't with the overall sequence that produces cooperative fetching in the scheduled module, but rather that the individual step of producing a nested thread binding doesn't feel sound to me. The scheduling primitives shouldn't be able to impact the output values, only the way in which the outputs are computed. Since there are definitely cases in which producing a nested thread binding changes the result, I was hoping there would be an alternative way to express cooperative fetching without requiring the nested thread binding.
@Lunderberg This is a good point! Actually, when designing TensorIR, @spectrometerHBH put a lot of thought into this specifically for cooperative fetching, and he has a proof that nesting thread bindings won't change the semantics of a TensorIR program during the scheduling stage as long as it satisfies the "compact dataflow" condition. Therefore, I wouldn't worry too much in this particular case. In fact, in the simplified non-nested case below, the TIR describes two separate kernels, but creating a TIR schedule from it errors out because the same env thread name appears more than once:

```python
@T.prim_func
def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32):
    T.func_attr({"tir.noalias": T.bool(True)})
    A = T.match_buffer(var_A, (1, seq_len * 8), "int32")
    B = T.match_buffer(var_B, (1, seq_len * 8), "int32", align=8)
    with T.block("exclusive_scan"):
        T.reads()
        T.writes()
        s8: T.int32 = seq_len * 8
        if s8 == 0:
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
        else:
            with T.launch_thread("threadIdx.x", 1024) as threadIdx_x:
                blockIdx_x = T.launch_thread("blockIdx.x", T.ceildiv(s8, 1024))
                i: T.int32 = blockIdx_x * 1024 + threadIdx_x
                if i < s8:
                    B[i // s8, i % s8] = A[i // s8, i % s8]
```

I added this as a test case.
There's a regression on the unity CI (https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-unity/detail/unity/868/pipeline); we will need this fix.
To quickly unblock the Unity CI, @vinx13 could you cherry-pick this commit to unity?
Lunderberg left a comment
Looks good, and thank you for adding the unit test!
@junrushao Ooh, nice! Is there a link to the proof? Asking because the test case below is a counter-example, where using `sch.bind` on two nested loops with the same thread axis changes the output values:

```python
#!/usr/bin/env python3

import tvm.testing
from tvm.script import tir as T

import numpy as np


def _run_mod(mod: tvm.IRModule) -> np.ndarray:
    target = "cuda"
    dev = tvm.device(target)

    np_A = np.ones([16, 16], "int32")
    np_B = np.zeros([16, 16], "int32")
    tvm_A = tvm.nd.array(np_A, dev)
    tvm_B = tvm.nd.array(np_B, dev)

    built = tvm.build(mod, target="cuda")
    built(tvm_A, tvm_B)

    # Result will have 1 for indices where the copy occurred, and 0
    # for indices where the copy failed to occur.
    return tvm_B.numpy()


def test_nested_thread_binding():
    @T.prim_func
    def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
        # Block index just makes sure this is a valid cuda kernel,
        # even before either i or j are bound to threadIdx.x.
        block_index = T.launch_thread("blockIdx.x", 1)
        for i, j in T.grid(16, 16):
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

    sch = tvm.tir.Schedule(func)
    i, j = sch.get_loops("copy")

    res_0 = _run_mod(sch.mod)

    sch.bind(i, "threadIdx.x")
    res_1 = _run_mod(sch.mod)

    sch.bind(j, "threadIdx.x")
    res_2 = _run_mod(sch.mod)

    # This assert passes, all values are copied from A to B.
    np.testing.assert_array_equal(res_0, res_1)

    # This assert fails, values are only copied from A to B along the
    # diagonal.
    np.testing.assert_array_equal(res_0, res_2)


if __name__ == "__main__":
    tvm.testing.main()
```
I think what the check needs to do is make sure that no two of the loops mapped to a block's iter vars are bound to the same thread axis. The rationale is that we have two iteration spaces: one formed by the loops, and the other formed by the block's iter vars. The block bindings describe how these two spaces are mapped onto each other, and effectively the order of execution of the block instances.

If we have two separate loops bound to the same thread axis, then the 2-D loop iteration space I x J effectively collapses to the 1-D space I, which is wrong. What you can do instead is fuse the loops first and then bind the fused loop to the thread axis.
TBH, I don't think these checks (or the proofs I crafted some time ago) cover all the cases rigorously, but at least in the case above it's very clear.
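For illustration, here is a small sketch of that fuse-then-bind alternative applied to the `func` from the counter-example above (assuming it is in scope):

```python
sch = tvm.tir.Schedule(func)
i, j = sch.get_loops(sch.get_block("copy"))

# Fuse i and j into a single loop of extent 256 and bind it once; every (i, j)
# pair is still visited, unlike binding i and j to "threadIdx.x" separately.
fused = sch.fuse(i, j)
sch.bind(fused, "threadIdx.x")
```

With the `_run_mod` helper from the counter-example, this version should still copy every element of A to B, matching `res_0`.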
Hmm. So for the cooperative fetching example that @junrushao mentioned earlier (link), it would be acceptable because the two loops bound to `threadIdx.x` are mapped to iter vars of different blocks. How would that condition be checked after the bindings have already been applied?
Completely agreed. My concern is that, because schedule primitives may be part of automatic performance tuning, the resulting schedule could produce results that are faster but incorrect. My goal would be to have a way to throw an error for my invalid test case without preventing the cooperative fetching case; I'm not sure whether that would be easier to achieve with a different sequence of primitives for cooperative fetching, or with a new analysis pathway that forbids the invalid binding.
During TensorIR scheduling, the `IterVar`s that represent environment threads may duplicate, i.e. it is legal to have two env threads with the same name tag, which may fail the `SingleEnvThreadVerifier` check during schedule creation. This PR disables this check in this case. In the future, it may be worthwhile to bring it back against post-scheduling TIR. It's related to this commit. CC: @jinhongyii @Lunderberg