
Conversation

@LeiWang1999
Contributor

@LeiWang1999 commented Jan 3, 2024

#759 proposed the storage_rewrite pass, which provides a trivial storage reuse plan based on liveness analysis. As #9341 mentioned, that solution has some limitations:

  1. storage_rewrite can't handle buffers with different dtypes.
    int8 A_shared[32];
    int8 B_shared[32];
    int32 C_shared[4]; // will not be reused even if there is enough workspace, because the dtypes differ.
  2. storage_rewrite can't allocate one buffer in the place of two other buffers (see the sketch after this list).
       int8 A_shared[32];
       int8 B_shared[32];
       int8 C_shared[64];
      // C_shared is merged into only one of them, leaving A_shared[32] and B_shared[64], which wastes 32 elements of space.
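
For intuition, here is a minimal hand-written CUDA sketch of the kind of layout a single merged pool makes possible (the buf_shmem name follows the generated code shown below; the offsets and liveness assumptions are illustrative only, not this pass's actual packing decisions):

__global__ void merged_pool_sketch(int* out) {
  // One untyped byte pool instead of per-buffer, per-dtype allocations.
  __shared__ __align__(16) unsigned char buf_shmem[64];

  // Limitation 1 goes away: different dtypes become typed views at byte offsets.
  signed char* A_shared = (signed char*)(buf_shmem + 0);   // 32 x int8
  signed char* B_shared = (signed char*)(buf_shmem + 32);  // 32 x int8
  if (threadIdx.x == 0) {
    A_shared[0] = 1;
    B_shared[0] = 2;
  }
  __syncthreads();
  int partial = A_shared[0] + B_shared[0];
  __syncthreads();

  // Limitation 2 goes away: once A_shared/B_shared are dead, a buffer of any
  // dtype and any size up to 64 bytes can take over their combined space.
  int* C_shared = (int*)(buf_shmem + 0);  // 4 x int32
  if (threadIdx.x == 0) C_shared[0] = partial;
  __syncthreads();
  if (threadIdx.x == 0) out[0] = C_shared[0];
}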

#8571 and #9341 introduced the MergeDynamicSharedMemoryAllocations pass, which supports efficient memory reuse, but only for dynamic shared memory. However, sometimes we do not want to use dynamic shared memory for codegen, so this pull request makes a simple extension to MergeDynamicSharedMemoryAllocations to support optimal reuse of both dynamic and static shared memory.

By default, static shared memory merging is disabled to keep the existing behavior. To enable it:

with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
    cuda_mod = tvm.build(sch.mod, target="cuda")

Take an int8xint8=int32 tensorcore GEMM with a big tile and static shared memory as an example. Before the pass:

__global__ void __launch_bounds__(128) Fused(int8_t* __restrict__ input0, int8_t* __restrict__ input1, int* __restrict__ output0) {
  
  int mediate0_shared_warp[128];
  __shared__ signed char input0_shared[16384];
  __shared__ signed char input1_shared[16384];
  signed char input0_shared_warp[64];
  signed char input1_shared_warp[64];
  signed char input0_shared_warp_1[64];
  signed char input1_shared_warp_1[64];
  __shared__ int mediate0_shared[6400];

the static allocations add up to 16384 + 16384 + 6400 * 4 = 58368 bytes, which exceeds the maximum available static shared memory (48 KB per block), so compilation fails. After this pass:

__global__ void __launch_bounds__(128) Fused(int8_t* __restrict__ input0, int8_t* __restrict__ input1, int* __restrict__ output0) {
  
  __shared__ uchar buf_shmem[32768];
  int mediate0_shared_warp[128];
  signed char input0_shared_warp[64];
  signed char input1_shared_warp[64];
  signed char input0_shared_warp_1[64];
  signed char input1_shared_warp_1[64];

we save around 50% of the shared memory (32768 bytes instead of 58368) and compilation succeeds. Code generation with fastdlight can achieve 510+ TFLOPS, whereas without the pass the best feasible tile reaches around 420 TFLOPS on A100. This pass enables us to explore more tile configs under static shared memory.
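
For intuition, the 32 KB pool works out because the two 16 KB input tiles are live at the same time, while the 25600-byte mediate0_shared tile can reuse their space once they are consumed. A rough hand-written sketch of such a layout (the offsets are illustrative; the pass computes the real ones):

__global__ void __launch_bounds__(128) Fused_layout_sketch(int* __restrict__ output0) {
  // One 32 KB pool: max(16384 + 16384, 25600) = 32768 bytes,
  // instead of 16384 + 16384 + 25600 = 58368 bytes of separate allocations.
  __shared__ __align__(16) unsigned char buf_shmem[32768];

  // While the input staging tiles are live:
  signed char* input0_shared = (signed char*)(buf_shmem + 0);      // 16384 B
  signed char* input1_shared = (signed char*)(buf_shmem + 16384);  // 16384 B
  input0_shared[threadIdx.x] = 0;
  input1_shared[threadIdx.x] = 0;
  __syncthreads();
  // ... mma pipeline consuming the input tiles ...
  __syncthreads();

  // After the input tiles are dead, the output staging tile reuses the pool:
  int* mediate0_shared = (int*)(buf_shmem + 0);                    // 6400 x int32
  mediate0_shared[threadIdx.x] = 0;
  __syncthreads();
  output0[threadIdx.x] = mediate0_shared[threadIdx.x];
}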

Moreover, the pass can optimize the dynamic shared memory plan as well: storage_rewrite would merge C_shared into B_shared in the example above, which is not friendly to further memory-plan analysis. The merge_static_smem flag disables that trivial reuse behavior via the check below (I don't know whether this flag handling can be improved):

// Skip the trivial in-place reuse and create a fresh allocation when reuse is
// disabled (e.g. the buffer will be handled by the merge pass instead), the
// array is small, or the memory space is not flat.
if (!enable_reuse || is_small_array || !is_flat_memory_space) {
  return NewAlloc(op, attach_scope, scope, const_nbits);
}

@junrushao
Member

This is a really amazing addition! In particular, I found it painful that the existing storage-rewrite pass doesn't handle heterogeneous dtypes (between which CUDA does support casting), and on the second point, yes, it's limited by the current rewriting - thanks for the contribution!


def run_passes(sch, args):
    mod = schedule_to_module(sch, args)
    with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
Member

In which case should we turn this flag off?

Contributor Author

tir.merge_static_smem is set to False by default so that the generated code stays more readable (for example, keeping clear definitions of A_shared and B_shared instead of (half*)(buf_shmem + offset)), so it has to be enabled manually.
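
To make the tradeoff concrete, here is a toy contrast between the two codegen styles (buffer names, sizes and offsets are purely illustrative, not actual pass output):

// Flag off (default): every buffer keeps its own named, typed definition.
__global__ void readable_style(int* out) {
  __shared__ signed char A_shared[128];
  __shared__ signed char B_shared[128];
  A_shared[threadIdx.x] = 1;
  B_shared[threadIdx.x] = 2;
  __syncthreads();
  out[threadIdx.x] = A_shared[threadIdx.x] + B_shared[threadIdx.x];
}

// Flag on: one merged byte pool, so every access becomes a cast-plus-offset
// expression such as ((half*)(buf_shmem + offset))[i] in the general case.
__global__ void merged_style(int* out) {
  __shared__ __align__(16) unsigned char buf_shmem[256];
  ((signed char*)(buf_shmem + 0))[threadIdx.x]   = 1;  // was A_shared
  ((signed char*)(buf_shmem + 128))[threadIdx.x] = 2;  // was B_shared
  __syncthreads();
  out[threadIdx.x] = ((signed char*)(buf_shmem + 0))[threadIdx.x]
                   + ((signed char*)(buf_shmem + 128))[threadIdx.x];
}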

Member

Got it - so it's turned off basically for better readability, is my understanding correct?

Member

@junrushao left a comment

Overall LGTM!

@vinx13 merged commit e3216a6 into apache:unity Jan 5, 2024
@jinhongyii
Contributor

It seems that this PR was not merged in squash mode.

@vinx13
Member

vinx13 commented Jan 5, 2024

Oops, sorry, this was a mistake.

@masahi
Member

masahi commented Jan 5, 2024

This PR should have been sent to main.
