Conversation


@CharlieFRuan CharlieFRuan commented Mar 3, 2025

Overview

This PR supports warp-level shuffle primitives using the newly introduced subgroups feature in WebGPU, and uses them in the allreduce lowering.

The introduced primitives are:

  • subgroupShuffle()
  • subgroupShuffleUp()
  • subgroupShuffleDown()
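To illustrate the warp-level allreduce pattern these primitives enable, here is a CPU-side TypeScript simulation of a shuffle-down tree sum. This is an illustrative sketch, not the PR's actual lowering; the helper name `subgroupSumViaShuffleDown` and the power-of-two subgroup size are assumptions.

```typescript
// CPU simulation of the shuffle-down allreduce pattern that
// subgroupShuffleDown() enables in WGSL. Each "lane" reads the value
// held by lane (id + delta), halving delta each step until lane 0
// holds the full sum.
function subgroupSumViaShuffleDown(lanes: number[]): number {
  const vals = lanes.slice();
  const size = vals.length; // assume a power-of-two subgroup size
  for (let delta = size >> 1; delta > 0; delta >>= 1) {
    // subgroupShuffleDown(v, delta): lane i receives lane (i + delta)'s value
    const shuffled = vals.map((_, i) => vals[(i + delta) % size]);
    for (let i = 0; i < size; i++) {
      vals[i] += shuffled[i];
    }
  }
  return vals[0]; // lane 0 now holds the reduction result
}

// Example: summing a 4-lane subgroup
subgroupSumViaShuffleDown([1, 2, 3, 4]); // → 10
```

On the GPU, each iteration is a single shuffle instruction per lane, avoiding the shared-memory round trips a workgroup-level reduction would need.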

This PR largely follows the Metal counterpart.

Tested end-to-end with WebLLM using Llama3.2-1B-q4f16_1 and Llama3.1-8B-q4f16_1. The dumped WebGPU kernels indeed contain subgroup shuffle primitives: https://gist.github.com/CharlieFRuan/cb54a8db0513ecbbc16c5de8df5ab845

Remaining TODOs

  • Benchmark the speedup
  • Parameterize whether to use subgroups when targeting WebGPU, since not all devices support them
  • Check GPUFeatureName's inclusion of subgroups in @webgpu/types
  • Some WebGPU devices allow more than 256 threads per block; support targeting these different limits
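For the parameterization TODO, the feature gating could look like the following sketch. The helper name `pickRequiredFeatures` is hypothetical; the `"subgroups"` feature name follows the WebGPU spec, and the browser usage in the trailing comment assumes a WebGPU-capable environment.

```typescript
// Only request the "subgroups" feature when the adapter reports it,
// so kernel generation can fall back to a shared-memory reduction on
// devices without support. The helper is pure so it can be exercised
// without a GPU; pickRequiredFeatures is a hypothetical name.
function pickRequiredFeatures(supported: ReadonlySet<string>): string[] {
  const requiredFeatures: string[] = [];
  if (supported.has("subgroups")) {
    requiredFeatures.push("subgroups");
  }
  return requiredFeatures;
}

// Browser usage (assumes navigator.gpu is available):
// const adapter = await navigator.gpu.requestAdapter();
// const device = await adapter!.requestDevice({
//   requiredFeatures: pickRequiredFeatures(adapter!.features) as GPUFeatureName[],
// });
```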

Resources

The review comment below is on this snippet:

```ts
}

const requiredFeatures: GPUFeatureName[] = [];
// TODO(Charlie): cannot type annotate because @webgpu/types
```
Contributor:

@webgpu/types 0.1.55 should work now. See gpuweb/types#167

Member Author:

Great, thanks!
