62 changes: 32 additions & 30 deletions src/transformers/models/llama/modeling_llama.py
@@ -372,21 +372,22 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        # Apply 2D sharding:
-        # activation (data, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding activations', attn_output.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # activation (data, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding activations', attn_output.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
 
         return attn_output, attn_weights, past_key_value
 
@@ -681,21 +682,22 @@ def forward(
 
         # Is this the input to the model?
         hidden_states = inputs_embeds
-        # Apply 2D sharding:
-        # input (data, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding hidden_states', hidden_states.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # input (data, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding hidden_states', hidden_states.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
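For reference, the sharding pattern that both hunks now guard behind `spmd_2d_sharding` can be distilled into a small helper like the sketch below. This is not part of the PR; the function name `shard_activation_2d` is hypothetical, and it assumes (as the diff suggests) that `spmd_2d_sharding` holds the size of the model axis of the 2D device mesh.

# Hedged sketch only, distilled from the diff above; not the PR's own code.
import torch
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr


def shard_activation_2d(tensor: torch.Tensor, model_axis: int) -> None:
    """Mark a (batch, seq, hidden) tensor for sharding over a (data, 1, model) mesh."""
    num_devices = xr.global_runtime_device_count()
    data_axis = num_devices // model_axis
    assert data_axis * model_axis == num_devices, "model axis must divide device count"
    # Batch dim maps to the data axis, sequence stays replicated, hidden maps to the model axis.
    mesh = xs.HybridMesh(ici_mesh_shape=(data_axis, 1, model_axis))
    xs.mark_sharding(tensor, mesh, (0, 1, 2))

With such a helper, both call sites would reduce to `shard_activation_2d(attn_output, self.spmd_2d_sharding)` and `shard_activation_2d(hidden_states, self.spmd_2d_sharding)` whenever `spmd_2d_sharding > 0`, keeping the mesh construction in one place.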