
Conversation

@junrushao
Member

During TensorIR scheduling, the IterVars that represent environment threads may be duplicated, i.e. it is legal to have two env threads with the same name tag, which can cause the SingleEnvThreadVerifier check to fail during schedule creation. This PR disables that check. In the future, it may be worthwhile to bring it back for post-scheduling TIR.

It's related to this commit. CC: @jinhongyii @Lunderberg

@junrushao junrushao marked this pull request as ready for review January 7, 2024 05:32
@Lunderberg
Contributor

Lunderberg commented Jan 7, 2024

Oh, interesting. I initially added the check after encountering this error in tir.transform.LowerWarpMemory; I concluded that requiring a single TIR variable per thread index was a TIR limitation and wanted to raise the error as early as possible.

Can we add a test case that includes multiple T.env_thread assignments within the same kernel, verifying that it can be lowered/built without error? That way, it would be clear that this behavior should be valid (so long as the multiple bindings all have the same extent, that is). I'm thinking something like the following:

import tvm
from tvm.script import tir as T


@T.prim_func
def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
    tx_A = T.launch_thread("threadIdx.x", 16)
    tx_B = T.launch_thread("threadIdx.x", 16)

    B[tx_B] = A[tx_A]


tvm.build(func, target="cuda")

@junrushao
Member Author

Hey thanks for getting back to me @Lunderberg!

The fundamental reason is that this pass is invoked during TIR schedule creation to validate whether the TIR is well-formed, and at that stage env threads are allowed to duplicate. Below is an example of a cumulative-sum operator we use when enabling MegaBlock-style MoE serving for Mixtral (see also PR mlc-ai/mlc-llm#1529). It's the first step in computing indptr and indices in the CSR/COO sparse format.

@T.prim_func(private=True)
def main(var_A: T.handle, var_T_add: T.handle, seq_len: T.int64):
    T.func_attr({"tir.noalias": T.bool(True)})
    A = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32")
    T_add = T.match_buffer(var_T_add, (seq_len * T.int64(8),), "int32")
    # with T.block("root"):
    T_expand_dims = T.alloc_buffer((T.int64(1), seq_len * T.int64(8)), "int32")
    output_buf = T.alloc_buffer((T.int64(1), seq_len * T.int64(8)), "int32", align=8)
    T_squeeze = T.alloc_buffer((seq_len * T.int64(8),), "int32")
    for ax0, ax1 in T.grid(T.int64(1), seq_len * T.int64(8)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v_ax1])
            T.writes(T_expand_dims[v_ax0, v_ax1])
            T_expand_dims[v_ax0, v_ax1] = A[v_ax1]
    with T.block("exclusive_scan"):
        T.reads()
        T.writes()
        if seq_len * T.int64(8) == T.int64(0):
            blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
            if blockIdx_x < T.int64(1):
                T.evaluate(0)
        else:
            with T.launch_thread("threadIdx.x", T.int64(1024)) as threadIdx_x:
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (seq_len * T.int64(8) + T.int64(1023)) // T.int64(1024)))
                blockIdx_y = T.launch_thread("blockIdx.y", T.int64(1))
                if blockIdx_x * T.int64(1024) + threadIdx_x < seq_len * T.int64(8):
                    output_buf[(blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) % (seq_len * T.int64(8))] = T_expand_dims[(blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + (blockIdx_x * T.int64(1024) + threadIdx_x)) % (seq_len * T.int64(8))]
            for i in range(T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8)))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 1024)
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (seq_len * T.int64(8) + (T.int64(1024) * T.shift_left(T.int64(2), i) - T.int64(1))) // (T.int64(1024) * T.shift_left(T.int64(2), i)))))
                blockIdx_y = T.launch_thread("blockIdx.y", T.int64(1))
                start = T.allocate([T.int64(1)], "int64", "local")
                middle = T.allocate([T.int64(1)], "int64", "local")
                end = T.allocate([T.int64(1)], "int64", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[T.int64(0)] = T.shift_left(T.int64(2), i) * T.Cast("int64", blockIdx_x * 1024 + threadIdx_x)
                if start_1[T.int64(0)] < seq_len * T.int64(8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[T.int64(0)] = start_1[T.int64(0)] + T.shift_left(T.int64(2), i) // T.int64(2)
                    end_1 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_1[T.int64(0)] = T.min(start_1[T.int64(0)] + T.shift_left(T.int64(2), i), seq_len * T.int64(8))
                    if middle_1[T.int64(0)] < seq_len * T.int64(8):
                        output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] + output_buf[(blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))]
            with T.launch_thread("blockIdx.x", T.int64(1)) as blockIdx_x:
                if blockIdx_x < T.int64(1):
                    output_buf[((blockIdx_x + T.int64(1)) * (seq_len * T.int64(8)) - T.int64(1)) // (seq_len * T.int64(8)), ((blockIdx_x + T.int64(1)) * (seq_len * T.int64(8)) - T.int64(1)) % (seq_len * T.int64(8))] = 0
            for j in range(T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8)))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 1024)
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (seq_len * T.int64(8) + (T.int64(1024) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)) - T.int64(1))) // (T.int64(1024) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1))))))
                blockIdx_y = T.launch_thread("blockIdx.y", T.int64(1))
                start = T.allocate([T.int64(1)], "int64", "local")
                middle = T.allocate([T.int64(1)], "int64", "local")
                end = T.allocate([T.int64(1)], "int64", "local")
                end_1 = T.allocate([T.int64(1)], "int32", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[T.int64(0)] = T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)) * T.Cast("int64", blockIdx_x * 1024 + threadIdx_x)
                if start_1[T.int64(0)] < seq_len * T.int64(8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[T.int64(0)] = start_1[T.int64(0)] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)) // T.int64(2)
                    end_2 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_2[T.int64(0)] = T.min(start_1[T.int64(0)] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float64", seq_len * T.int64(8))))) - j - T.int64(1)), seq_len * T.int64(8))
                    if middle_1[T.int64(0)] < seq_len * T.int64(8):
                        end_3 = T.Buffer((1,), "int32", data=end_1, scope="local")
                        end_3[T.int64(0)] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))]
                        output_buf[(blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + middle_1[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))]
                        output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] = output_buf[(blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) // (seq_len * T.int64(8)), (blockIdx_y * (seq_len * T.int64(8)) + end_2[T.int64(0)] - T.int64(1)) % (seq_len * T.int64(8))] + end_3[T.int64(0)]
    for ax0 in range(seq_len * T.int64(8)):
        with T.block("T_squeeze"):
            v_ax0 = T.axis.spatial(seq_len * T.int64(8), ax0)
            T.reads(output_buf[T.int64(0), v_ax0])
            T.writes(T_squeeze[v_ax0])
            T_squeeze[v_ax0] = output_buf[T.int64(0), v_ax0]
    for ax0 in range(seq_len * T.int64(8)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(seq_len * T.int64(8), ax0)
            T.reads(A[v_ax0], T_squeeze[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + T_squeeze[v_ax0]

As you can see, threadIdx.x is launched more than once in this IR, and this is actually valid for codegen :((

@Lunderberg
Contributor

Ah, I see. It looks like this originates from using sch.bind multiple times for the same thread index. With that usage, I'm less sure whether this should be correct behavior. When binding a thread index that isn't currently bound, the total number of loop iterations remains the same. When binding a thread index in a context where it is already bound, it constrains the two loop indices to always be the same.

# Before
for i in range(16):
    for j in range(16):
        B[i, j] = A[i, j]

# First thread binding: still 16*16 = 256 total iterations.  B is
# set at all indices.
for i in T.thread_binding(16, thread="threadIdx.x"):
    for j in range(16):
        B[i, j] = A[i, j]

# Second thread binding: now only 16 iterations.  B is only set
# for indices where i == j, along the diagonal.
for i in T.thread_binding(16, thread="threadIdx.x"):
    for j in T.thread_binding(16, thread="threadIdx.x"):
        B[i, j] = A[i, j]

@junrushao
Member Author

On CUDA, when fetching data from global memory to shared memory, we use something called cooperative fetching, where the data is divided evenly and all threads cooperatively fetch their own portions. To implement this logic in TVM, we have to bind threadIdx.x one extra time:

  • Originally, it binds an outer loop to threadIdx.x as part of the tiling process
  • It uses sch.cache_read + sch.compute_at to move the data copy under a loop in the TIR
  • It splits the data-copy loop by the extent of threadIdx.x and then binds the inner loop

It's the default schedule strategy used in MetaSchedule's multi-level tiling, and here's an example of this use case: https://github.com/apache/tvm/blob/main/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py#L71
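For illustration, here is a minimal hand-written sketch of that sequence on a toy element-wise kernel (the kernel, tile sizes, and variable names are made up for this sketch and are not taken from the MetaSchedule rule); the second sch.bind on "threadIdx.x" is exactly the duplicated env thread that used to trip SingleEnvThreadVerifier:

import tvm
from tvm.script import tir as T


@T.prim_func
def ewise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * T.float32(2)


sch = tvm.tir.Schedule(ewise)
blk = sch.get_block("C")
i, j = sch.get_loops(blk)

# Step 1: tiling binds an outer loop to threadIdx.x.
i0, i1 = sch.split(i, factors=[None, 16])
sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")

# Step 2: cache_read + compute_at move the copy of A into shared memory
# under the blockIdx.x tile.
A_shared = sch.cache_read(blk, 0, "shared")
sch.compute_at(A_shared, i0)

# Step 3: split the generated copy loops by the thread extent and bind the
# inner loop to threadIdx.x again -- a second binding of the same thread tag.
fused = sch.fuse(*sch.get_loops(A_shared)[1:])
_, tx = sch.split(fused, factors=[None, 16])
sch.bind(tx, "threadIdx.x")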

@Lunderberg
Contributor

Thank you for the details on cooperative fetching, and I agree that this check should be removed to make sure the use case is preserved, at least for the time being. I think that, so long as we have a unit test showing the functionality, this PR can be merged.

My worry wasn't with the overall sequence that produces cooperative fetching in the scheduled module, but rather that the individual step of producing a nested thread binding doesn't feel sound to me. The scheduling primitives shouldn't be able to affect the output values, only the way in which the outputs are computed. Since there are definitely cases in which producing a nested thread binding changes the result, I was hoping there would be an alternative way to express cooperative fetching without requiring the nested thread binding.

@junrushao junrushao force-pushed the bugfix/2024-01-06/disable-single-env-verifier branch from aaae5a2 to 78ff3f9 on January 10, 2024 23:09
@junrushao junrushao force-pushed the bugfix/2024-01-06/disable-single-env-verifier branch from 78ff3f9 to 92be12a on January 11, 2024 00:12
@junrushao
Member Author

junrushao commented Jan 11, 2024

The scheduling primitives shouldn't be able to affect the output values, only the way in which the outputs are computed. Since there are definitely cases in which producing a nested thread binding changes the result, I was hoping there would be an alternative way to express cooperative fetching without requiring the nested thread binding.

@Lunderberg This is a good point! Actually, when designing TensorIR, @spectrometerHBH put a lot of thought into cooperative fetching specifically, and he has a proof that nested thread bindings won't change the semantics of a TensorIR program during the scheduling stage as long as it satisfies the "compact dataflow" condition. Therefore, I wouldn't worry too much in this particular case.

In fact, in the simplified non-nested case below, the TIR describes two separate kernels, but creating a TIR schedule errors out because the blockIdx.x env threads in different blocks are different variables. This is actually a valid case, because the SplitHostDevice pass will later split it into two kernels.

@T.prim_func
def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32):
    T.func_attr({"tir.noalias": T.bool(True)})
    A = T.match_buffer(var_A, (1, seq_len * 8), "int32")
    B = T.match_buffer(var_B, (1, seq_len * 8), "int32", align=8)
    with T.block("exclusive_scan"):
        T.reads()
        T.writes()
        s8: T.int32 = seq_len * 8
        if s8 == 0:
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
        else:
            with T.launch_thread("threadIdx.x", 1024) as threadIdx_x:
                blockIdx_x = T.launch_thread("blockIdx.x", T.ceildiv(s8, 1024))
                i: T.int32 = blockIdx_x * 1024 + threadIdx_x
                if i < s8:
                    B[i // s8, i % s8] = A[i // s8, i % s8]

I added this as a test case.
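For reference, a minimal sketch of what such a test exercises (an illustration, not the exact test code added by this PR):

import tvm

# With the SingleEnvThreadVerifier check disabled, schedule creation on the
# PrimFunc above succeeds instead of raising during well-formedness checking.
sch = tvm.tir.Schedule(two_kernels)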

@vinx13
Member

vinx13 commented Jan 11, 2024

There's a regression on the unity CI (https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-unity/detail/unity/868/pipeline); we will need this fix.

@junrushao
Member Author

To quickly unblock the Unity CI, @vinx13 could you cherry-pick this commit to unity?

Contributor

@Lunderberg Lunderberg left a comment

Looks good, and thank you for adding the unit test!

@tqchen tqchen merged commit e2e33dd into apache:main Jan 11, 2024
@Lunderberg
Contributor

@Lunderberg This is a good point! Actually, when designing TensorIR, @spectrometerHBH put a lot of thought into cooperative fetching specifically, and he has a proof that nested thread bindings won't change the semantics of a TensorIR program during the scheduling stage as long as it satisfies the "compact dataflow" condition. Therefore, I wouldn't worry too much in this particular case.

@junrushao Ooh, nice! Is there a link to the proof? Asking because the test case below is a counter-example, where using sch.bind to introduce a nested thread binding does change the output of the PrimFunc. Between the proof, the implementation of CheckSubtreeCompactDataflow, and the lowering pipeline, one of the three must have an error, and the proof would help to narrow it down.

#!/usr/bin/env python3

import tvm.testing
from tvm.script import tir as T

import numpy as np


def _run_mod(mod: tvm.IRModule) -> np.ndarray:
    target = "cuda"
    dev = tvm.device(target)

    np_A = np.ones([16, 16], "int32")
    np_B = np.zeros([16, 16], "int32")

    tvm_A = tvm.nd.array(np_A, dev)
    tvm_B = tvm.nd.array(np_B, dev)

    built = tvm.build(mod, target="cuda")
    built(tvm_A, tvm_B)

    # Result will have 1 for indices where the copy occurred, and 0
    # for indices where the copy failed to occur.
    return tvm_B.numpy()


def test_nested_thread_binding():
    @T.prim_func
    def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
        # Block index just makes sure this is a valid cuda kernel,
        # even before either i or j are bound to threadIdx.x.
        block_index = T.launch_thread("blockIdx.x", 1)

        for i, j in T.grid(16, 16):
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

    sch = tvm.tir.Schedule(func)
    i, j = sch.get_loops("copy")

    res_0 = _run_mod(sch.mod)

    sch.bind(i, "threadIdx.x")
    res_1 = _run_mod(sch.mod)

    sch.bind(j, "threadIdx.x")
    res_2 = _run_mod(sch.mod)

    # This assert passes, all values are copied from A to B
    np.testing.assert_array_equal(res_0, res_1)

    # This assert fails, values are only copied from A to B along the
    # diagonal.
    np.testing.assert_array_equal(res_0, res_2)


if __name__ == "__main__":
    tvm.testing.main()

@spectrometerHBH
Contributor

spectrometerHBH commented Jan 11, 2024

@Lunderberg

I think what the check needs to do is make sure that no two of the loops that are mapped to block vars are bound to the same thread axis.

The rationale is that we have two iteration spaces: one formed by the loops i0, i1, ..., and the other formed by the block vars v0, v1, ....

The block bindings describe how these two spaces are mapped to each other, and effectively the order in which block instances execute.
If two block instances parameterized by block vars differ only in spatial axes, we consider them parallelizable. When parallelizing block instances, we usually only consider one-to-one mappings between loops and block vars, which means different loop iterations launch different block instances.

If we have two separate loops bound to the same thread axis, then the 2-D loop iteration space I × J effectively collapses to the 1-D space I, which is wrong.

What you can do instead is either

  • fuse them and bind the fused loop to the thread axis (as sketched below), or
  • bind them to different thread axes,

which are both OK under the check.

TBH, I don't think the checks (or the proofs I crafted some time ago) cover all the cases rigorously, but at least in the case above it's very clear.
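For illustration, here is a minimal sketch of the first option applied to the copy kernel from the counterexample above (copy_func is a made-up name for this sketch): fusing the loops before binding keeps all 16 * 16 = 256 iterations on a single threadIdx.x.

import tvm
from tvm.script import tir as T


@T.prim_func
def copy_func(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")):
    for i, j in T.grid(16, 16):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]


sch = tvm.tir.Schedule(copy_func)
i, j = sch.get_loops(sch.get_block("copy"))
fused = sch.fuse(i, j)            # a single loop of extent 256
sch.bind(fused, "threadIdx.x")    # one binding; no iterations are lost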

@Lunderberg
Contributor

I think what the check needs to do is make sure that no two of the loops that are mapped to block vars are bound to the same thread axis.

Hmm. So for the cooperative fetching example that @junrushao mentioned earlier (link), it would be acceptable because the ax0_ax1_fused_1 binding is used for A_shared and B_shared, while the i0_2_i1_2_fused binding is used for C and C_local, and no block has both in use at the same time?

How would that be checked after ConvertBlocksToOpaque? At that point, only the loop bindings remain, but we still have two loop bindings to the same thread index?

If we have two separate loops bound to the same thread axis, then the 2-D loop iteration space I × J effectively collapses to the 1-D space I, which is wrong.

Completely agreed. My concern is that, because schedule primitives may be part of automatic performance tuning, the resulting schedule could produce results that are faster, but incorrect.

My goal would be to have a way to throw an error for my invalid test case, without preventing the cooperative fetching case. Whether that would be easier by having a different sequence of primitives to arrive at cooperative fetching, or a new analysis pathway to forbid the sch.bind in my test case, I'm not sure.
