
Commit 2fe0200

alanwaketan authored and yeounoh committed
Enable 2D sharding (#17)
Summary: This pull request adds 2D SPMD sharding to the table. It shards both weights and activations. Given a 2D mesh (data, model) with data x model == num_devices, the sharding strategy is:
1. input (data, None, model)
2. embedding (model, data)
3. attn QKV (data, model)
4. attn O (model, data)
5. mlp gate, up (model, data)
6. mlp down (data, model)
7. activation (data, None, model)

Currently you can specify the model dimension with the new --spmd_2d_sharding option; the data dimension is then derived automatically as num_devices // model. TODO: consider adding another option to control whether the activations/inputs are sharded at all, or sharded differently.
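For reference, a minimal sketch of the weight half of this strategy, using the same torch_xla SPMD calls the diff relies on (xs.HybridMesh, xs.mark_sharding). The tensor shapes and the model dimension of 4 below are illustrative, not taken from this commit, and the snippet assumes a TPU process already running in SPMD mode:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
model = 4                          # value that would be passed via --spmd_2d_sharding
data = num_devices // model        # the data dimension is auto-calculated
assert model * data == num_devices

# Two views of the same devices: (data, model) and (model, data).
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, model))
model_data_mesh = xs.HybridMesh(ici_mesh_shape=(model, data))

# Illustrative LLaMA-7B-like shapes: attn QKV weights get (data, model),
# while the embedding table gets (model, data).
q_proj_weight = torch.zeros(4096, 4096, device=xm.xla_device())
xs.mark_sharding(q_proj_weight, data_model_mesh, (0, 1))
embed_tokens_weight = torch.zeros(32000, 4096, device=xm.xla_device())
xs.mark_sharding(embed_tokens_weight, model_data_mesh, (0, 1))
print(torch_xla._XLAC._get_xla_sharding_spec(embed_tokens_weight))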
1 parent 1b9c7d8 commit 2fe0200

File tree

3 files changed: +69 -3 lines changed


examples/pytorch/language-modeling/run_clm.py

Lines changed: 46 additions & 0 deletions
@@ -189,6 +189,14 @@ class ModelArguments:
             )
         },
     )
+    spmd_2d_sharding: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "Will apply XLA SPMD to 2D sharding, i.e., weights + activations, and spmd_2d_sharding specifies the model dimension"
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -468,6 +476,8 @@ def main():
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
 
+    # Pass the 2d sharding config to the actual model.
+    config.spmd_2d_sharding = model_args.spmd_2d_sharding
     if model_args.model_name_or_path:
         torch_dtype = (
             model_args.torch_dtype
@@ -538,6 +548,42 @@ def main():
             else:
                 assert len(param.shape) == 2
                 xs.mark_sharding(param, mesh, range(len(param.shape)))
+    elif model_args.spmd_2d_sharding > 0:
+        print('Applying 2D sharding to all parameters')
+        for name, param in model.named_parameters():
+            # Apply 2D sharding:
+            # embedding (model, data)
+            # attn QKV (data, model)
+            # attn O (model, data)
+            # mlp gate, up (model, data)
+            # mlp down (data, model)
+            print('> Sharding tensor', name, param.shape)
+            mod = model_args.spmd_2d_sharding
+            data = num_devices // mod
+            assert mod * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod))
+            model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data))
+
+            # We don't care about layernorm's weights, and
+            # LLaMA doesn't use biases.
+            if len(param.shape) == 1:
+                continue
+
+            if 'embed_tokens' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'o_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'gate_proj' in name or 'up_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'down_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'lm_head' in name:  # lm_head (output projection) has the same shape as embed_tokens
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+
+            import torch_xla
+            print(torch_xla._XLAC._get_xla_sharding_spec(param))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.

src/transformers/models/llama/modeling_llama.py

Lines changed: 19 additions & 0 deletions
@@ -401,6 +401,22 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
+        # Apply 2D sharding:
+        # activation (data, None, model)
+        import torch_xla.core.xla_model as xm
+        import torch_xla.experimental.xla_sharding as xs
+        import torch_xla.runtime as xr
+        import torch_xla
+        num_devices = xr.global_runtime_device_count()
+        device_ids = torch.arange(num_devices)
+        print('> Sharding activations', attn_output.shape)
+        model = self.spmd_2d_sharding
+        data = num_devices // model
+        assert model * data == num_devices
+        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+
         return attn_output, attn_weights, past_key_value
 
 
@@ -920,6 +936,9 @@ class LlamaModel(LlamaPreTrainedModel):
 
     def __init__(self, config: LlamaConfig):
         super().__init__(config)
+        # For PyTorch/XLA's SPMD 2D sharding
+        self.spmd_2d_sharding = config.spmd_2d_sharding
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 

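For reference, a minimal sketch of what the activation annotation above amounts to, using the same calls as the diff. The batch, sequence, and hidden sizes and the model dimension of 4 are illustrative, not taken from this commit, and the snippet assumes a TPU process already running in SPMD mode:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
model = 4                      # plays the role of self.spmd_2d_sharding in the diff
data = num_devices // model

# Three mesh axes so the 3D activation can be annotated per dimension:
# batch -> data axis, seq_len -> replicated (size-1 axis), hidden -> model axis.
mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))

# Stand-in for attn_output with a (batch, seq_len, hidden) shape.
attn_output = torch.zeros(8, 2048, 4096, device=xm.xla_device())
xs.mark_sharding(attn_output, mesh, (0, 1, 2))
print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))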
src/transformers/trainer.py

Lines changed: 4 additions & 3 deletions
@@ -1548,10 +1548,11 @@ def _xla_sharded_dataloader(self, dataloader):
         if self.args.spmd_batch_sharding:
             mesh = xs.Mesh(device_ids, (num_devices, 1))
             sharding_spec = xs.ShardingSpec(mesh, (0, 1))
-        elif self.args.spmd_tensor_sharding > 0:
-            tensor = self.args.spmd_tensor_sharding
+        elif self.args.spmd_tensor_sharding > 0 or self.args.spmd_2d_sharding > 0:
+            assert self.args.spmd_tensor_sharding == 0 or self.args.spmd_2d_sharding == 0
+            tensor = self.args.spmd_tensor_sharding + self.args.spmd_2d_sharding
             fsdp = num_devices // tensor
-            mesh = xs.Mesh(device_ids, (fsdp, tensor))
+            mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))
             partition_spec = (0, None)
             sharding_spec = xs.ShardingSpec(mesh, partition_spec)
 

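For reference, a minimal sketch of the input sharding this trainer change produces: the (fsdp, tensor) hybrid mesh with partition_spec (0, None) shards the batch dimension of each (batch, seq_len) input along the data/fsdp axis and replicates the sequence dimension. Shapes and the 4-way model dimension below are illustrative, not taken from this commit:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
tensor = 4                        # --spmd_2d_sharding (or --spmd_tensor_sharding)
fsdp = num_devices // tensor
mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))

# A (batch, seq_len) token batch, sharded the way the ShardingSpec above shards
# dataloader outputs: batch over the fsdp/data axis, sequence replicated.
input_ids = torch.zeros(8, 2048, dtype=torch.long, device=xm.xla_device())
xs.mark_sharding(input_ids, mesh, (0, None))
print(torch_xla._XLAC._get_xla_sharding_spec(input_ids))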