Support compilation via Torchdynamo, AOT Autograd, NVFuser #17308
Changes from all commits
2396218
0b4c279
52503f2
298b3ba
6756fa5
28f80ec
bf41704
7925fa3
ce47619
9daaf93
@@ -62,6 +62,7 @@
    require_torch_non_multi_gpu,
    require_torch_tf32,
    require_torch_up_to_2_gpus,
    require_torchdynamo,
    require_wandb,
    slow,
)
@@ -1594,6 +1595,100 @@ def test_fp16_full_eval(self):
        # perfect world: fp32_init/2 == fp16_eval
        self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)

    @require_torch_non_multi_gpu
    @require_torchdynamo
    def test_torchdynamo_full_eval(self):
        # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
        n_gpus = get_gpu_count()

        bs = 8
        eval_len = 16 * n_gpus
        # make the params somewhat big so that there will be enough RAM consumed to be able to
        # measure things. We should get about 64KB for a+b in fp32
        a = torch.ones(1000, bs) + 0.001
        b = torch.ones(1000, bs) - 0.001

        # 1. Default - without TorchDynamo
        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
        metrics = trainer.evaluate()
        original_eval_loss = metrics["eval_loss"]
        del trainer

        # 2. TorchDynamo eager
        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
        metrics = trainer.evaluate()
        self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
        del trainer

        # 3. TorchDynamo nvfuser
        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
        metrics = trainer.evaluate()
        self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
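For context, a minimal sketch of how the torchdynamo option exercised above would be used outside the test helpers, assuming an already-built model and eval_dataset (placeholders, not from this PR); the get_regression_trainer calls presumably forward the same flag into TrainingArguments:

# Minimal sketch; `model`, `eval_dataset`, and the output path are placeholders.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="tmp_dynamo_eval",  # placeholder output path
    per_device_eval_batch_size=8,
    torchdynamo="nvfuser",  # the option added by this PR; "eager" is also accepted
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
metrics = trainer.evaluate()
print(metrics["eval_loss"])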
    @require_torch_non_multi_gpu
    @require_torchdynamo
    def test_torchdynamo_memory(self):
        # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
        class CustomTrainer(Trainer):
            def compute_loss(self, model, inputs, return_outputs=False):
                x = inputs["x"]
                output = model(x)
                if self.args.n_gpu == 1:
                    return output.mean()
                return output

        class MyModule(torch.nn.Module):
            """Simple module that does aggressive fusion"""

            def __init__(self):
                super().__init__()

            def forward(self, x):
                for _ in range(20):
(Review thread attached to the line above.)

@stas00 Can you try changing this number to 10 to see if the speed improves? If it doesn't, let's make it 1 to see if this is the culprit. At 1 the test will fail, but we will have a little more info to debug.

I think I figured out the trigger: the Trainer tried to run 2-GPU DP, which somehow led to hanging. Forcing the test to run on one GPU gets past the hang, but it now fails:

I see. For single GPU, the [...] I am not sure why it hangs for DP though (maybe it's compiling for each GPU node, and thus compilation time is shooting up?). Is it possible to limit TorchDynamo usage to single GPUs only? We have not really tested TorchDynamo for distributed training.

Perhaps it's because I have a lopsided setup? GPU models and configuration: [...] Is the compiled version hardware agnostic?

It works on 2 GPUs with DP as well - should we remove the restrictions then and put it back how you coded it originally?

Oh, awesome. Thanks for putting time into this. There is no need to change the test back to the original one. My commit earlier extended the test to single-GPU as well, so the test works for both single and multi-GPU.

Except due to the [...]

The hang seems like concerning behavior. Glad it seems to have sorted itself out, but please let us know if it returns in any form. I do have a single mixed-GPU system I could try reproducing on if it comes back, though I would always recommend running DDP across matching GPUs.

The hanging was on a single GPU as well, with a PyTorch nightly from 05-18.
                    x = torch.nn.functional.relu(x)
                return x

        mod = MyModule()

        # 1. Default - without TorchDynamo
        a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
        a.grad = None
        trainer = CustomTrainer(model=mod)
        # warmup
        for _ in range(10):
            orig_loss = trainer.training_step(mod, {"x": a})

        torch.cuda.reset_peak_memory_stats()
        orig_loss = trainer.training_step(mod, {"x": a})
        orig_peak_mem = torch.cuda.max_memory_allocated()
        del trainer

        # Reset the peak for another measurement
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # 2. TorchDynamo nvfuser
        a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
        a.grad = None
        args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
        trainer = CustomTrainer(model=mod, args=args)
        # warmup
        for _ in range(10):
            loss = trainer.training_step(mod, {"x": a})

        torch.cuda.reset_peak_memory_stats()
        loss = trainer.training_step(mod, {"x": a})
        peak_mem = torch.cuda.max_memory_allocated()
        del trainer

        # Functional check
        self.assertAlmostEqual(loss, orig_loss)

        # AOT Autograd recomputation and nvfuser recomputation optimization
        # aggressively fuse the operations and reduce the memory footprint.
        self.assertGreater(orig_peak_mem, peak_mem * 2)
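The measurement pattern above can also be reproduced outside the Trainer. The following is a minimal standalone sketch, assuming a CUDA device and mirroring the MyModule defined inside the test, of how peak memory for one forward/backward step is captured with torch.cuda.reset_peak_memory_stats and torch.cuda.max_memory_allocated:

# Minimal standalone sketch of the peak-memory measurement used in the test above.
# Assumes a CUDA device; MyModule mirrors the module defined inside the test.
import torch


class MyModule(torch.nn.Module):
    def forward(self, x):
        for _ in range(20):
            x = torch.nn.functional.relu(x)
        return x


def peak_mem_of_step(step_fn):
    """Return the peak CUDA memory allocated while running step_fn once."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()


mod = MyModule().cuda()
x = torch.ones(1024, 1024, device="cuda", requires_grad=True)
peak = peak_mem_of_step(lambda: mod(x).mean().backward())
print(f"peak CUDA memory: {peak / 2**20:.1f} MiB")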
    @require_torch_gpu
    @require_torch_bf16
    def test_bf16_full_eval(self):