
Conversation

@yzh119 (Member) commented Feb 10, 2022

Previously, the LowerThreadAllReduce pass would only emit code using shfl_down when the reduce extent equals the warp size; when the reduce extent is less than the warp size, the codegen falls back to shared-memory-based reduction, which is less efficient. Since CUDA supports sub-warp reduction by specifying a mask, we can still use the shuffle-down approach for reduction by adjusting the mask.

Example code:

import tvm
import numpy as np
from tvm.script import tir as T


@T.prim_func
def reduce(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1024, 11])
    B = T.match_buffer(b, [1024])

    for i, j in T.grid(1024, 11):
        with T.block("reduce"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                B[vi] = 0.
            B[vi] = B[vi] + A[vi, vj]

sch = tvm.tir.Schedule(reduce)
blk = sch.get_block("reduce")
i, j = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")
f = tvm.build(sch.mod["main"], target="cuda")
print(f.imported_modules[0].get_source())

Emitted code before this PR:

extern "C" __global__ void __launch_bounds__(11) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
  __shared__ float red_buf0[11];
  __syncthreads();
  ((volatile float*)red_buf0)[(((int)threadIdx.x))] = A[(((((int)blockIdx.x) * 11) + ((int)threadIdx.x)))];
  __syncthreads();
  if (((int)threadIdx.x) < 3) {
    ((volatile float*)red_buf0)[(((int)threadIdx.x))] = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 8))]);
  }
  __syncthreads();
  if (((int)threadIdx.x) < 4) {
    float w_4_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 4))]);
    ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_4_0;
    float w_2_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 2))]);
    ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_2_0;
    float w_1_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 1))]);
    ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_1_0;
  }
  __syncthreads();
  B[(((int)blockIdx.x))] = ((volatile float*)red_buf0)[(0)];
}

Emitted code after this PR:

extern "C" __global__ void __launch_bounds__(11) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
  float red_buf0[1];
  uint mask[1];
  float t0[1];
  red_buf0[(0)] = A[(((((int)blockIdx.x) * 11) + ((int)threadIdx.x)))];
  mask[(0)] = __activemask();
  t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32);
  red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
  t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32);
  red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
  t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32);
  red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
  t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32);
  red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
  red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32);
  B[(((int)blockIdx.x))] = red_buf0[(0)];
}
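For completeness, the result of the built kernel can be sanity-checked against NumPy in a few lines (a quick sketch, not part of the PR's test suite):

dev = tvm.cuda(0)
a_np = np.random.rand(1024, 11).astype("float32")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), dev)
f(a, b)  # f is the module built above with tvm.build
np.testing.assert_allclose(b.numpy(), a_np.sum(axis=1), rtol=1e-5)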

Future work

CUDA 11 supports cooperative-group reduction, which we could use directly.

cc @vinx13 @junrushao1994

@Hzfengsy (Member)

Do you have any performance results? Also, please add test cases.

@yzh119 (Member Author) commented Feb 10, 2022

Sure, below are the measured kernel times:

@T.prim_func
def reduce(a: T.handle, b: T.handle, n: T.int32) -> None:
    A = T.match_buffer(a, [1048576, n])
    B = T.match_buffer(b, [1048576])

    for i, j in T.grid(1048576, n):
        with T.block("reduce"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                B[vi] = 0.
            B[vi] = B[vi] + A[vi, vj]

and vary n over 2, 4, 8, 16, 32.

| n                      | 2           | 4           | 8           | 16          | 32          |
| ---------------------- | ----------- | ----------- | ----------- | ----------- | ----------- |
| shared-mem time (ms)   | 0.836363387 | 0.902631863 | 1.214023657 | 1.249731274 | 1.175273217 |
| shuffle-down time (ms) | 0.80920489  | 0.999711047 | 1.076497658 | 1.103504739 | 1.116779527 |

There is some variance across multiple runs. Time is measured with TVM's native time_evaluator, averaging over 1000 runs.
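Timings like these can be collected with TVM's time_evaluator roughly as follows (a minimal sketch using the concrete [1024, 11] kernel and the f built in the PR description; the benchmark itself uses the 1048576-row kernel with n specialized for each run):

dev = tvm.cuda(0)
a = tvm.nd.array(np.random.rand(1024, 11).astype("float32"), dev)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), dev)
# average the kernel time over 1000 runs with TVM's time_evaluator
evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
print("time (ms):", evaluator(a, b).mean * 1e3)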

@junrushao (Member)

CC @MasterJH5574 I believe you are interested

@yzh119 (Member Author) commented Feb 10, 2022

Some other notes:

Consider the following case:

@T.prim_func
def reduce(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1, 4, 8])
    B = T.match_buffer(b, [1, 4])

    for i, j, k in T.grid(1, 4, 8):
        with T.block("reduce"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                B[vi, vj] = 0.
            B[vi, vj] = B[vi, vj] + A[vi, vj, vk]

If we bind j to threadIdx.y and k to threadIdx.x, different values of j may be mapped to the same warp, so we need a different mask for each value of j to distinguish the groups.

Below is an example of generated code:

extern "C" __global__ void __launch_bounds__(32) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
  float red_buf0[1];
  uint mask[1];
  float t0[1];
  red_buf0[(0)] = A[(((((int)threadIdx.y) * 8) + ((int)threadIdx.x)))];
  mask[(0)] = (__activemask() & ((uint)(255 << (((int)threadIdx.y) * 8))));
  t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32);
  red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
  t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32);
  red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
  t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32);
  red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
  red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], (((int)threadIdx.y) * 8), 32);
  B[(((int)threadIdx.y))] = red_buf0[(0)];
}

Another thing worth noting: shuffle-down can only reduce within a single warp, so when blockDim.y * blockDim.z != 1 the warp size must be a multiple of blockDim.x (otherwise a reduction group would straddle a warp boundary).
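To make this concrete, here is a small Python sketch (hypothetical helper names, not the actual pass code) of the applicability check and of the per-group mask that corresponds to the 255 << (threadIdx.y * 8) expression in the generated code above:

WARP_SIZE = 32  # assumed warp size

def subwarp_shuffle_ok(reduce_extent, group_extent):
    # reduce_extent corresponds to blockDim.x; group_extent to blockDim.y * blockDim.z.
    # Shuffle instructions only exchange data within a warp, so when several
    # reduction groups share a block, no group may straddle a warp boundary,
    # i.e. the warp size must be a multiple of blockDim.x.
    if group_extent != 1 and WARP_SIZE % reduce_extent != 0:
        return False
    return reduce_extent <= WARP_SIZE

def group_mask(tidx_y, reduce_extent):
    # each group of reduce_extent lanes gets its own contiguous lane mask,
    # e.g. reduce_extent == 8 gives 255 << (threadIdx.y * 8)
    return ((1 << reduce_extent) - 1) << (tidx_y * reduce_extent)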

@yzh119 (Member Author) commented Feb 10, 2022

@Hzfengsy I wrote a unit test and found a bug (#10210 ) in the original shared-memory-based tree reduction; it is fixed in this PR.

@MasterJH5574 (Contributor)

Interesting. Looks like the perf improvement isn't very much? Only when n = 4 is the shuffle-down implementation better than the shared-memory implementation 🤔

> Another thing worth noting: shuffle-down can only reduce within a single warp, so when blockDim.y * blockDim.z != 1 the warp size must be a multiple of blockDim.x.

BTW do we have this requirement in the codebase now?

@yzh119 (Member Author) commented Feb 11, 2022

> Looks like the perf improvement isn't very much? Only when n = 4 is the shuffle-down implementation better than the shared-memory implementation 🤔

That was a typo in the table; I have fixed it.

Another benefit of using shuffle-down is reduced shared memory usage (e.g., the kernel above no longer needs the shared red_buf0 buffer), which increases the number of blocks that can execute concurrently.

@yzh119 (Member Author) commented Feb 11, 2022

> BTW do we have this requirement in the codebase now?

@MasterJH5574 yes, there is a notion of group_extent and reduce_extent.

@junrushao (Member)

will leave the PR to @vinx13 and @masahi for a second look :-)

@vinx13 vinx13 merged commit e13110f into apache:main Feb 15, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
