Conversation


@HeyangQin HeyangQin commented Jan 6, 2024

Previously we used a series of forward/backward flags to control whether hpz should be enabled on a given allgather call. This PR simplifies that by enabling hpz only when its secondary tensor exists (and invalidating the secondary tensor whenever the master weights change). This should:

  1. Prevent potential out-of-sync issues compared with our current approach of overwriting the secondary tensor
  2. Improve throughput, because hpz will now be enabled in many more scenarios, including i) activation checkpointing, ii) gradient accumulation, iii) `torch.no_grad` context, iv) `model.eval()` mode, v) LoRA frozen weights, vi) gradient overflow

This fixes #4851.
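A minimal sketch of the gating logic described above, for illustration only; the class and helper names are hypothetical stand-ins, not DeepSpeed's actual internals:

```python
# Illustrative sketch of "enable hpz only when the secondary tensor exists";
# names are hypothetical stand-ins, not DeepSpeed's real internals.

class ShardedParam:
    def __init__(self, primary_shard, secondary_shard=None):
        self.ds_tensor = primary_shard              # regular ZeRO-3 partition
        self.ds_secondary_tensor = secondary_shard  # node-local hpz copy, may be None


def use_secondary_tensor(params):
    """Take the hpz allgather path only when every parameter still holds a
    valid node-local secondary copy; otherwise fall back to the primary shards."""
    return all(p.ds_secondary_tensor is not None for p in params)


def invalidate_secondary_tensors(params):
    """Called whenever master weights change (e.g. after an optimizer step):
    cached node-local copies are stale, so drop them and let the next full
    gather repopulate them."""
    for p in params:
        p.ds_secondary_tensor = None
```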

Convergence test:

  • llama-2-7b random weights, using https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/llama2/run_llama2_7b.sh

zero-3 Baseline: Evaluating perplexity, Epoch 4/4: ppl: 5.151907920837402, loss: 1.6393671035766602
hpz with this PR: ppl: 5.081737518310547, loss: 1.6256532669067383

  • llama-2-7b pretrained weights with LoRA, using https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/llama2/run_llama2_7b_lora.sh

zero-3 Baseline: Evaluating perplexity, Epoch 4/4: ppl: 1.8326854705810547, loss: 0.6057823896408081
hpz with this PR: ppl: 1.8326854705810547, loss: 0.6057823896408081

Performance test on 32 V100 GPUs, still using https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/llama2/run_llama2_7b.sh (a rough config sketch follows the numbers below):

  • gradient accumulation step = 8

master branch with hpz: SamplesPerSec=17.567813158654847
this patch with hpz: SamplesPerSec=24.121657876029225

  • lora

master branch with hpz: SamplesPerSec=33.88883430864484
this patch with hpz: SamplesPerSec=43.39463460004735
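
For reference, a rough sketch of how these runs map onto a DeepSpeed config: `zero_hpz_partition_size` under ZeRO stage 3 is the knob that enables hpz, and gradient accumulation is a plain top-level setting. The values below are illustrative placeholders, not the exact settings from run_llama2_7b.sh:

```python
# Illustrative config sketch only; numbers are placeholders, not the exact
# values used by run_llama2_7b.sh.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,   # the "gradient accumulation step = 8" run
    "fp16": {"enabled": True},          # V100s use fp16 rather than bf16
    "zero_optimization": {
        "stage": 3,
        "zero_hpz_partition_size": 8,   # enable hpz: node-local secondary partition
    },
}
```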

@tjruwase tjruwase requested a review from samadejacobs January 6, 2024 03:44
@HeyangQin HeyangQin marked this pull request as ready for review January 8, 2024 16:49

@samadejacobs samadejacobs left a comment


LGTM, good job @HeyangQin

@HeyangQin HeyangQin requested a review from loadams as a code owner January 14, 2024 14:07
@mrwyattii

Manually running nightly tests here: https://github.com/microsoft/DeepSpeed/actions/runs/7658819103

@mrwyattii mrwyattii enabled auto-merge January 25, 2024 19:04
@mrwyattii mrwyattii added this pull request to the merge queue Jan 25, 2024
Merged via the queue into master with commit 75ed63c Jan 25, 2024
@mrwyattii mrwyattii deleted the HeyangQin/mixz_hpz_fix branch January 25, 2024 22:53
@siddharth9820

Is hpz safe to use now?

mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
Successfully merging this pull request may close these issues.

[BUG] convergence issues with zero_hpz_partition_size