Conversation


@alanwaketan alanwaketan commented Aug 1, 2023

Summary:
This pull request adds 2D SPMD sharding to the table. It shards both weights and activations. Here is the sharding strategy:

Let's say we have a 2D mesh (data, model) and data x model == num_devices:

  1. input (data, None, model)
  2. embedding (model, data)
  3. attn QKV (data, model)
  4. attn O (model, data)
  5. mlp gate, up (model, data)
  6. mlp down (data, model)
  7. activation (data, None, model)

Currently you can specify the size of the model axis using a new option, --spmd_2d_sharding; the data axis is then derived automatically as num_devices / model. A rough sketch of how this maps onto mesh construction and mark_sharding calls follows after the TODO below.

TODO: maybe we should add another option to control whether to shard the activations/inputs at all, or to shard them differently.
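To make the mapping concrete, here is a minimal sketch of how the strategy above could translate into mesh construction and mark_sharding calls, assuming the xs.Mesh / xs.mark_sharding API quoted in the diffs below; the concrete option value, the parameter-name checks, and the `model` variable are illustrative, not the exact code in this PR.

import numpy as np
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Illustrative values: `spmd_2d_sharding` would come from --spmd_2d_sharding.
spmd_2d_sharding = 4
num_devices = xr.global_runtime_device_count()
mod = spmd_2d_sharding                      # model axis size
data = num_devices // mod                   # data axis size, derived automatically
device_ids = np.arange(num_devices)

# Two 2D views over the same devices: (data, model) and (model, data).
data_model_mesh = xs.Mesh(device_ids, (data, mod))
model_data_mesh = xs.Mesh(device_ids, (mod, data))

# `model` is assumed to be the LLaMA module being trained.
for name, param in model.named_parameters():
    if 'embed_tokens' in name:              # embedding: (model, data)
        xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
    elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
        # attn QKV: (data, model)
        xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
    elif 'o_proj' in name:                  # attn O: (model, data)
        xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
    # ... the MLP projections follow the same pattern; see the review thread
    # below, where their mesh assignment is discussed and corrected.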

@jonb377 jonb377 left a comment

LGTM, nice one Jiewen!

Comment on lines 540 to 541
data_model_mesh = xs.Mesh(device_ids, (data, mod))
model_data_mesh = xs.Mesh(device_ids, (mod, data))
@jonb377 jonb377 Aug 1, 2023

Can you try with HybridMesh? It should provide some performance gain, but I haven't actually benchmarked the difference. Here and in modeling_llama.py

@khatwanimohit may have some benchmarked differences on the simple shardings.py script
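(For context, a rough sketch of what this suggestion might look like, assuming xs.HybridMesh from torch_xla's SPMD utilities with ici_mesh_shape / dcn_mesh_shape keyword arguments; the single-slice dcn shape is an assumption, and this is not code from the PR.)

# Hypothetical replacement for the two xs.Mesh calls quoted above; assumes a
# single slice (dcn_mesh_shape of all ones) and the HybridMesh constructor
# from torch_xla's SPMD utilities.
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod), dcn_mesh_shape=(1, 1))
model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data), dcn_mesh_shape=(1, 1))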

@alanwaketan (Collaborator, Author)

Let me do that. I always forget.

@alanwaketan (Collaborator, Author)

Fixed that too.

Comment on lines 554 to 557
elif 'gate_proj' in name or 'up_proj' in name:
xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
elif 'down_proj' in name:
xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
@jonb377

Just for my understanding: I noticed that HF shards gate_proj and up_proj on the 0th dim and down_proj on the 1st dim, but here you're sharding gate and up on the data_model mesh, which places the model axis on dim 1.

Is this just a difference in 1- and 2-D sharding?

@alanwaketan (Collaborator, Author)

That's a good catch. I don't know. Let me dig into it. I'm following the slides attached at the top of the spreadsheet.

@jonb377

No worries! I was just curious, using the sharding from the slides makes sense.

@alanwaketan (Collaborator, Author)

Yea, you are right. I have corrected the error.
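(For readers following the thread: a hypothetical sketch of what the corrected branch would look like if, as discussed above, the model axis moves to dim 0 for gate_proj/up_proj and to dim 1 for down_proj, matching items 5 and 6 of the strategy list; this is inferred from the discussion, not copied from the final patch.)

# Inferred correction: swap the meshes so that gate/up follow (model, data)
# and down follows (data, model), as in the strategy list above.
elif 'gate_proj' in name or 'up_proj' in name:
    xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
elif 'down_proj' in name:
    xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))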

@alanwaketan

Thanks Jon for approving the pull request.

@alanwaketan alanwaketan merged commit 813af25 into llama2-google-next-training Aug 1, 2023
alanwaketan added a commit that referenced this pull request Aug 2, 2023
Summary:
This pull request fixes a bug in #17, where 2D sharding for activations and inputs was not guarded.

Test Plan:
N/A.
alanwaketan added a commit that referenced this pull request Oct 27, 2023
alanwaketan added a commit that referenced this pull request Oct 27, 2023
yeounoh pushed a commit that referenced this pull request Mar 19, 2024
yeounoh pushed a commit that referenced this pull request Mar 19, 2024
vanbasten23 pushed a commit that referenced this pull request May 21, 2024
vanbasten23 pushed a commit that referenced this pull request May 21, 2024