[EP] add initial support for NVSHMEM-based all-to-all #1569
base: main
@@ -166,12 +166,15 @@ def __init__(self):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
+        self.input_shape = None
+        self.permuted_indices = None

     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
-        ep_size = device_mesh.shape[0]
+        ep_degree = device_mesh.shape[0]
+        num_local_experts = num_tokens_per_expert.shape[0] // ep_degree

         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
@@ -184,12 +187,12 @@ def _token_dispatch(self, mod, inputs, device_mesh):
                 group=device_mesh.get_group(),
             )
             input_splits = (
-                num_tokens_per_expert.view(ep_size, -1)
+                num_tokens_per_expert.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
             output_splits = (
-                num_tokens_per_expert_group.view(ep_size, -1)
+                num_tokens_per_expert_group.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
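
Not part of the diff, but to make the splits computation concrete: the per-(global-)expert counts collapse into per-EP-rank all-to-all splits via a row-sum over a (rank, expert) view. The numbers below are made up.

```python
import torch

# 4 global experts, ep_degree = 2, so experts 0-1 live on EP rank 0 and
# experts 2-3 on EP rank 1 (contiguous blocks, as the .view below assumes).
num_tokens_per_expert = torch.tensor([3, 1, 2, 4])
ep_degree = 2

input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
print(input_splits)  # tensor([4, 6]): send 4 tokens to rank 0, 6 to rank 1
```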
@@ -212,11 +215,21 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # Rather, it is of the format
         # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
         #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
-        # We need to perform another shuffle to get the correct format -- this is done via the function
-        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
-        # each expert gets locally is a multiple of ALIGN_SIZE_M.
+        # We need to perform another shuffle to get the correct layout, via the _permute function
+        # below, which also does padding to make sure the number of tokens each expert gets locally
+        # is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.

-        return routed_input, num_tokens_per_expert_group
+        (
+            self.input_shape,
+            routed_input,
+            self.permuted_indices,
+            num_tokens_per_expert_group,
+            offsets,
+        ) = _permute(
+            routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
+        )
+
+        return routed_input, num_tokens_per_expert_group, offsets

     @staticmethod
     def _partition_fn(name, mod, device_mesh):
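
A small worked example (illustrative numbers, not from the PR) of the layout described in the comment above and of the regrouping plus alignment padding that _permute is responsible for:

```python
import torch

ep_degree, num_local_experts = 2, 2
align = 8  # stands in for TOKEN_GROUP_ALIGN_SIZE_M

# After the all-to-all, counts arrive grouped by source EP rank:
# [rank0->expert0, rank0->expert1, rank1->expert0, rank1->expert1]
num_tokens_per_expert_group = torch.tensor([3, 1, 2, 4])

# Grouped GEMM wants tokens grouped by local expert instead:
# expert 0 gets 3 + 2 = 5 tokens, expert 1 gets 1 + 4 = 5.
per_expert = num_tokens_per_expert_group.view(ep_degree, num_local_experts).sum(dim=0)

# generate_permute_indices additionally rounds each group up to a multiple
# of `align`; the extra slots point at a zero pad row (see _permute below).
padded = (per_expert + align - 1) // align * align
print(per_expert, padded)  # tensor([5, 5]) tensor([8, 8])
```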
@@ -227,6 +240,10 @@ def _partition_fn(name, mod, device_mesh):

     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
+        routed_output = _unpermute(
+            routed_output, self.input_shape, self.permuted_indices
+        )
+
         routed_output = all_to_all_single_autograd(
             routed_output,
             self.input_splits,
@@ -247,20 +264,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:

 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
 class ExpertTensorParallel(ExpertParallel):
-    def __init__(
-        self,
-        tp_mesh: DeviceMesh,
-        ep_mesh: DeviceMesh,
-    ):
-        super().__init__()
-        # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
-        # as DeviceMesh doesn't support slicing from a submesh.
-        self.tp_mesh = tp_mesh
-        self.ep_mesh = ep_mesh
-
     def _token_dispatch(self, mod, inputs, device_mesh):
         # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_dispatch(mod, inputs, self.ep_mesh)
+        return super()._token_dispatch(mod, inputs, device_mesh["ep"])

     def _partition_fn_2d(self, name, mod, ep_tp_mesh):
         # w1 shape = (experts, out_dim, in_dim)
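
The device_mesh["ep"] slicing above relies on the parent [ep, tp] mesh carrying dimension names. A minimal sketch of that assumption (mesh shape and names are illustrative, and this needs to run under a distributed launcher with 8 ranks):

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 ranks arranged as a 4x2 [ep, tp] mesh; the mesh_dim_names are what make
# device_mesh["ep"] / device_mesh["tp"] slicing possible.
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("ep", "tp"))
ep_mesh = mesh["ep"]  # 1-D submesh used for token dispatch/combine
tp_mesh = mesh["tp"]
```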
@@ -283,7 +289,7 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):

     def _token_combine(self, mod, routed_output, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, self.ep_mesh)
+        return super()._token_combine(mod, routed_output, device_mesh["ep"])

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
@@ -295,25 +301,42 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )


-def expert_parallel(func: Callable) -> Callable:
+def _permute(x, num_tokens_per_expert, ep_degree, num_local_experts):
+    # TODO: move to core
+    from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+
+    global TOKEN_GROUP_ALIGN_SIZE_M
+    with torch.no_grad():
+        (permuted_indices, num_tokens_per_expert, offsets,) = generate_permute_indices(
+            num_tokens_per_expert,
+            num_local_experts,
+            ep_degree,
+            x.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+        )
+
+    x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+    input_shape = x.shape
+    x = x[permuted_indices, :]
+
+    return input_shape, x, permuted_indices, num_tokens_per_expert, offsets
+
+
+def _unpermute(out, input_shape, permuted_indices):
+    out_unpermuted = out.new_empty(input_shape)
+    out_unpermuted[permuted_indices, :] = out
+    out = out_unpermuted[:-1]
+    return out
+
+
+def indices_permutation_wrapper(func: Callable) -> Callable:
     """
-    This is a wrapper applied to the GroupedExperts computation, serving
-    the following three purposes:
-    1. Convert parameters from DTensors to plain Tensors, to work with
-       dynamic-shape inputs which cannot be easily expressed as DTensors.
-    2. In Expert Parallel, apply the generate_permute_indices kernel to
-       permute the inputs to be ordered by local experts (see the _token_dispatch
-       function in ExpertParallel) and permute the outputs back.
-    3. In order to use torch._grouped_mm, we need to make sure the number of
-       tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
-       kernel also helps achieve this via padding, without incurring synchronization
-       between device and host. Note that this will create side effects when wrapping
-       the for-loop implementation of GroupedExperts, as it does not need padding.
-
-    Among the above:
-    1 and 2 are needed only when expert_parallel_degree > 1.
-    3 is needed even for single-device computation.
-    2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
+    In order to use torch._grouped_mm, we need to make sure the number of
+    tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The
+    generate_permute_indices kernel also helps achieve this via padding,
+    without incurring synchronization between device and host. Note that
+    this will create side effects when wrapping the for-loop implementation
+    of GroupedExperts, as it does not need padding.
     """

     def wrapper(
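
The _permute/_unpermute pair above relies on a small indexing trick: one all-zeros row is appended to x, and every padded slot of permuted_indices points at that row (e.g. via index -1), so the padded groups read zeros and _unpermute can scatter the real rows back and drop the pad. A self-contained illustration with made-up indices:

```python
import torch

x = torch.arange(6.0).reshape(3, 2)                       # 3 real tokens
x_padded = torch.vstack((x, x.new_zeros((x.shape[-1]))))  # append one zero row
perm = torch.tensor([2, 0, -1, -1, 1, -1])                # hypothetical permuted_indices
x_perm = x_padded[perm, :]                                # pad slots read the zero row

out_unpermuted = x_perm.new_empty(x_padded.shape)         # what _unpermute does:
out_unpermuted[perm, :] = x_perm                          # scatter real rows back ...
restored = out_unpermuted[:-1]                            # ... and drop the pad row
assert torch.equal(restored, x)
```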
@@ -322,40 +345,18 @@ def wrapper(
         w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
+        _offsets: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        global TOKEN_GROUP_ALIGN_SIZE_M
         if isinstance(w1, DTensor):
             w1 = w1.to_local()
             w2 = w2.to_local()
             w3 = w3.to_local()
+        num_local_experts = w1.shape[0]
+        ep_degree = num_tokens_per_expert.shape[0] // num_local_experts

-        from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
-
-        experts_per_ep_rank = w1.shape[0]
-        num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-
-        with torch.no_grad():
-            (
-                permuted_indices,
-                num_tokens_per_expert,
-                _,  # offsets,
-            ) = generate_permute_indices(
-                num_tokens_per_expert,
-                experts_per_ep_rank,
-                num_ep_ranks,
-                x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
-                TOKEN_GROUP_ALIGN_SIZE_M,
-            )
-
-        x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
-        input_shape = x.shape
-        x = x[permuted_indices, :]
+        input_shape, x, permuted_indices, num_tokens_per_expert, offsets = _permute(
+            x, num_tokens_per_expert, ep_degree, num_local_experts
+        )

-        out = func(w1, w2, w3, x, num_tokens_per_expert)
+        out = func(w1, w2, w3, x, num_tokens_per_expert, offsets)

-        out_unpermuted = out.new_empty(input_shape)
-        out_unpermuted[permuted_indices, :] = out
-        out = out_unpermuted[:-1]
+        out = _unpermute(out, input_shape, permuted_indices)

         return out
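
The new offsets value passed to func is what a grouped-matmul consumer needs: cumulative end offsets of the aligned per-expert token groups, in int32. That reading of generate_permute_indices is my assumption; the shape of the idea is just a prefix sum over the padded group sizes:

```python
import torch

# Hypothetical aligned group sizes for 3 local experts (multiples of 8):
aligned_counts = torch.tensor([8, 8, 16], dtype=torch.int32)
offsets = torch.cumsum(aligned_counts, dim=0, dtype=torch.int32)
print(offsets)  # tensor([ 8, 16, 32], dtype=torch.int32): group i ends at offsets[i]
```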
@@ -373,7 +374,7 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh):
             selected_experts_indices, device_mesh, (Replicate(),)
         )

-        # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
+        # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
         #     num_tokens = top_scores.shape[0]
         #     tp_size = device_mesh.size()
@@ -409,3 +410,145 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             input_fn=self._prepare_inputput_fn,
             output_fn=self._prepare_output_fn,
         )
+
+
+# TODO: let multiple MoE layers share the same input / output buffer
+# TODO: add NVSHMEMExpertTensorParallel support
+class NVSHMEMExpertParallel(ParallelStyle):
+    def __init__(
+        self,
+        num_tokens: int,
+        dim: int,
+        num_experts: int,
+        ep_mesh: DeviceMesh,
+        dtype: torch.dtype,
+    ):
+        import torch.distributed._symmetric_memory as symm_mem
+
+        from torchtitan.experiments.kernels.moe.combine import TokenCombiner
+        from torchtitan.experiments.kernels.moe.dispatch import TokenDispatcher
+        from torchtitan.tools.utils import device_type
+
+        super().__init__()
+
+        ep_degree = ep_mesh.shape[0]
+        # bs * slen * top_k, or
+        # (bs * slen // tp_degree) * top_k if ReordererSequenceParallel is used
+        input_length = num_tokens
+        # TODO: make overflow_factor configurable? but can cause IMA
+        # worst case: one rank receives all data
+        overflow_factor = ep_degree
+        max_output_length = input_length * overflow_factor
+
+        device = torch.device(device_type)
+        self.input_buffer = symm_mem.empty(
+            num_tokens,
+            dim,
+            dtype=dtype,
+            device=device,
+        )
+        self.output_buffer = symm_mem.empty(
+            max_output_length, dim, dtype=dtype, device=device
+        )
+        # two rows: input splits, input offsets
+        self.in_splits_offsets_buffer = symm_mem.empty(
+            (2, num_experts), dtype=torch.int64, device=device
+        )
+        # two rows: output splits, output offsets
+        self.out_splits_offsets_buffer = symm_mem.empty(
+            (2, num_experts), dtype=torch.int64, device=device
+        )
+
+        group_name = ep_mesh.get_group().group_name
+        num_local_experts = num_experts // ep_degree
+
+        global TOKEN_GROUP_ALIGN_SIZE_M
+        self.dispatcher = TokenDispatcher(
+            group_name,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+            input_length,
+            max_output_length,
+            [dim],
+            ep_degree,
+            num_local_experts,
+            dtype,
+            device,
+        )
+        self.combiner = TokenCombiner(
+            group_name,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+            max_output_length,
+            input_length,
+            [dim],
+            ep_degree,
+            num_local_experts,
+            dtype,
+            device,
+        )
+
+        self.input_splits = None
+        self.output_splits = None
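
To make the worst-case sizing concrete (all numbers below are illustrative, not from the PR): the output buffer must absorb the case where every rank routes all of its tokens to this rank, hence overflow_factor = ep_degree, which is also why the TODO about making it configurable mentions the IMA risk.

```python
# Illustrative sizing only; the names mirror the constructor above.
bs, slen, top_k, tp_degree = 4, 2048, 8, 1
ep_degree = 8
dim = 4096

input_length = bs * slen * top_k // tp_degree    # 65536 routed tokens per rank
max_output_length = input_length * ep_degree     # 524288 rows if one rank receives everything
output_bytes_bf16 = max_output_length * dim * 2  # 4 GiB symmetric output buffer
```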
+
+    # performing all-to-all dispatch on the input
+    def _token_dispatch(self, mod, inputs, device_mesh):

Inline review comment on _token_dispatch: i think this new implementation will get rid of the need of

+        # annotate module input placements/sharding with input_layouts
+        routed_input, num_tokens_per_expert = inputs
+        ep_degree = device_mesh.shape[0]
+
+        self.input_splits = num_tokens_per_expert
+        self.in_splits_offsets_buffer[0].copy_(self.input_splits)
+        input_buffer = self.input_buffer.detach()
+        output_buffer = self.output_buffer.detach()
+        input_buffer.copy_(routed_input)
+        output = self.dispatcher(
+            input_buffer,
+            output_buffer,
+            self.in_splits_offsets_buffer[0],
+            self.out_splits_offsets_buffer,
+        )
+
+        # NOTE: output_splits layout:
+        # for i in range(num_local_experts):
+        #     for j in range(ep_degree):
+        #         output_splits[i * ep_degree + j] denotes:
+        #         number of tokens passed from EP rank j to local expert i
+        output_splits = self.out_splits_offsets_buffer[0]
+        output_offsets = self.out_splits_offsets_buffer[1]
+
+        # TODO: need to simplify this
+        offsets = torch.zeros_like(output_offsets[::ep_degree])
+        offsets[:-1] = output_offsets[ep_degree::ep_degree]
+        offsets[-1] = output_offsets[-1] + output_splits[-1]
+
+        return output, None, offsets.to(dtype=torch.int32)
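
A worked example of the offsets arithmetic above, assuming out_splits_offsets_buffer[1] holds exclusive prefix sums of the splits (alignment padding ignored here); with ep_degree = 2 and 2 local experts:

```python
import torch

ep_degree = 2
# output_splits[i * ep_degree + j] = tokens from EP rank j for local expert i
output_splits = torch.tensor([3, 1, 2, 4])
output_offsets = torch.tensor([0, 3, 4, 6])  # assumed exclusive prefix sums

offsets = torch.zeros_like(output_offsets[::ep_degree])
offsets[:-1] = output_offsets[ep_degree::ep_degree]
offsets[-1] = output_offsets[-1] + output_splits[-1]
print(offsets)  # tensor([ 4, 10]): expert 0 owns rows [0, 4), expert 1 owns [4, 10)
```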
+
+    @staticmethod
+    def _partition_fn(name, mod, device_mesh):
+        # shard on the expert dimension
+        for name, param in mod.named_parameters(recurse=False):
+            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
+            mod.register_parameter(name, dist_param)
+
+    # performing all-to-all combine on the output
+    def _token_combine(self, mod, routed_output, device_mesh):
+        input_buffer = self.input_buffer.detach()
+        output_buffer = self.output_buffer.detach()
+        output_buffer.copy_(routed_output)
+
+        routed_output = self.combiner(
+            output_buffer,
+            input_buffer,
+            self.out_splits_offsets_buffer,
+            self.in_splits_offsets_buffer,
+        )
+
+        return routed_output
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            partition_fn=ExpertParallel._partition_fn,
+            input_fn=self._token_dispatch,
+            output_fn=self._token_combine,
+        )
Second file in the diff:
@@ -88,7 +88,9 @@ def forward(  # type: ignore[no-untyped-def]
             out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
         )
         ctx.group_name = group_name
-        return out
+
+        # TODO: why do we need this clone?
+        return out.clone()

     @staticmethod
     def backward(  # type: ignore[no-untyped-def]

Review comments on lines +92 to +93:
- Can you try removing this clone after we added out_buffer.detach()?
- still erroring out if removing this clone
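
One possible explanation for the clone, not confirmed anywhere in this thread: if out aliases the shared output buffer, later writes into that buffer also mutate whatever downstream code saved, and clone() materializes a private copy. A toy illustration of that aliasing hazard:

```python
import torch

buf = torch.zeros(4)      # stands in for a reused communication buffer
out = buf[:2]             # a view handed back to the caller
safe = out.clone()        # what the extra clone() buys us
buf.copy_(torch.ones(4))  # the next iteration reuses the buffer
print(out)                # tensor([1., 1.]) -- the view was clobbered
print(safe)               # tensor([0., 0.]) -- the clone was not
```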
Review comment: To make this reusable with the GPT-oss implementation, can we somehow make w1, w2, w3 a list of parameters or kwargs (basically take a variable number of weights and biases)? I think all this wrapper does with these inputs is pass them further to func().
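
A minimal sketch of what the suggested signature could look like (illustrative only, not the PR's code; the permute/pad logic from the diff is omitted):

```python
from typing import Callable

import torch
from torch.distributed.tensor import DTensor


def variadic_experts_wrapper(func: Callable) -> Callable:
    # Accept any number of expert weights/biases and forward them unchanged,
    # converting DTensors to local tensors along the way.
    def wrapper(*weights: torch.Tensor, **kwargs) -> torch.Tensor:
        weights = tuple(
            w.to_local() if isinstance(w, DTensor) else w for w in weights
        )
        return func(*weights, **kwargs)

    return wrapper
```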
Review comment: Another thing that is not reusable for gpt-oss is ExpertTensorParallel(). I guess for this part, if a model has a different mathematical formula and a variable number of weights/biases, it's the user's responsibility to update _partition_fn_2d in ETP, wdyt?