6 changes: 6 additions & 0 deletions src/transformers/testing_utils.py
@@ -70,6 +70,7 @@
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
is_torchdynamo_available,
is_vision_available,
)

@@ -464,6 +465,11 @@ def require_torch_tpu(test_case):
jax_device = None


def require_torchdynamo(test_case):
"""Decorator marking a test that requires TorchDynamo"""
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)


def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
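
For reference, a minimal sketch of how the new decorator would be used in a test module; the test class, method name, and assertion are illustrative and not part of this PR, and it assumes the standalone torchdynamo package's optimize("eager") context manager, the same API used in trainer.py below:

import unittest

import torch

from transformers.testing_utils import require_torchdynamo


class HypotheticalDynamoTest(unittest.TestCase):
    @require_torchdynamo
    def test_relu_under_dynamo(self):
        # The decorator skips this test unless `torchdynamo` is importable,
        # so the import below is safe.
        import torchdynamo

        with torchdynamo.optimize("eager"):
            x = torch.ones(4, 4)
            self.assertTrue(torch.equal(torch.nn.functional.relu(x), x))
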
34 changes: 31 additions & 3 deletions src/transformers/trainer.py
@@ -139,8 +139,10 @@
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
is_torchdynamo_available,
logging,
)
from .utils.generic import ContextManagers


_is_torch_generator_available = False
@@ -2172,6 +2174,32 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s

return inputs

def compute_loss_context_manager(self):
"""
A helper wrapper to group together context managers.
"""
return ContextManagers(
[
self.torchdynamo_smart_context_manager(),
self.autocast_smart_context_manager(),
]
)

def torchdynamo_smart_context_manager(self):
"""
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
"""
ctx_manager = contextlib.nullcontext()
if is_torchdynamo_available():
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

if self.args.torchdynamo == "eager":
ctx_manager = torchdynamo.optimize("eager")
elif self.args.torchdynamo == "nvfuser":
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
return ctx_manager

def autocast_smart_context_manager(self):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
@@ -2213,7 +2241,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.autocast_smart_context_manager():
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)

if self.args.n_gpu > 1:
@@ -2907,7 +2935,7 @@ def prediction_step(
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
with self.autocast_smart_context_manager():
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()

@@ -2917,7 +2945,7 @@
logits = outputs[1:]
else:
loss = None
with self.autocast_smart_context_manager():
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
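
The two new helpers above are combined through ContextManagers, so training_step and prediction_step enter the TorchDynamo context (when requested and available) and the autocast context in a single with statement. Below is a standalone sketch of that composition pattern using only contextlib; ContextManagersSketch and loss_context are illustrative names, not the PR's code, and the torchdynamo call mirrors the one in the diff above:

import contextlib


class ContextManagersSketch:
    """Enter a list of context managers in order and exit them in reverse,
    mirroring the idea behind transformers' ContextManagers helper."""

    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = contextlib.ExitStack()

    def __enter__(self):
        for cm in self.context_managers:
            self.stack.enter_context(cm)
        return self

    def __exit__(self, *exc):
        return self.stack.__exit__(*exc)


def loss_context(torchdynamo_backend=None, use_amp=False):
    # Both contexts default to no-ops, like the nullcontext() fallback in
    # torchdynamo_smart_context_manager above.
    dynamo_ctx = contextlib.nullcontext()
    amp_ctx = contextlib.nullcontext()
    if torchdynamo_backend == "eager":
        import torchdynamo  # only when the standalone package is installed

        dynamo_ctx = torchdynamo.optimize("eager")
    if use_amp:
        import torch

        amp_ctx = torch.cuda.amp.autocast()
    return ContextManagersSketch([dynamo_ctx, amp_ctx])


# Example:
# with loss_context(torchdynamo_backend="eager", use_amp=False):
#     loss = compute_loss(model, inputs)
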
2 changes: 1 addition & 1 deletion src/transformers/trainer_seq2seq.py
@@ -183,7 +183,7 @@ def prediction_step(
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

with torch.no_grad():
with self.autocast_smart_context_manager():
with self.compute_loss_context_manager():
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
17 changes: 17 additions & 0 deletions src/transformers/training_args.py
@@ -450,6 +450,9 @@ class TrainingArguments:
full_determinism (`bool`, *optional*, defaults to `False`)
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
distributed training
torchdynamo (`str`, *optional*):
The backend compiler to use for TorchDynamo. Possible choices are ["eager",
"nvfuser"]. This is an experimental API and subject to change.
"""

output_dir: str = field(
@@ -881,6 +884,20 @@ class TrainingArguments:
)
},
)
torchdynamo: Optional[str] = field(
default=None,
metadata={
"help": (
"Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
" make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
" before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
" and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
),
"choices": ["eager", "nvfuser"],
},
)

def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
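
Given the new torchdynamo field above, enabling the experimental backend would look roughly like the sketch below; the output directory, batch size, and example script are placeholders, and model/dataset setup is omitted:

from transformers import TrainingArguments

# "eager" is mainly for debugging; "nvfuser" takes the AOT Autograd + nvFuser path.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    torchdynamo="nvfuser",
)
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()

# The same option as a command-line flag for an example script:
#   python run_glue.py --output_dir out --torchdynamo nvfuser ...
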
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -130,6 +130,7 @@
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
is_torchdynamo_available,
is_training_run_on_sagemaker,
is_vision_available,
requires_backends,
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -376,6 +376,10 @@ def is_torch_tpu_available():
return importlib.util.find_spec("torch_xla.core.xla_model") is not None


def is_torchdynamo_available():
return importlib.util.find_spec("torchdynamo") is not None


def is_datasets_available():
return _datasets_available

95 changes: 95 additions & 0 deletions tests/trainer/test_trainer.py
@@ -62,6 +62,7 @@
require_torch_non_multi_gpu,
require_torch_tf32,
require_torch_up_to_2_gpus,
require_torchdynamo,
require_wandb,
slow,
)
@@ -1594,6 +1595,100 @@ def test_fp16_full_eval(self):
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)

@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_full_eval(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore this test requires a single GPU
n_gpus = get_gpu_count()

bs = 8
eval_len = 16 * n_gpus
# make the params somewhat big so that there will be enough RAM consumed to be able to
# measure things. We should get about 64KB for a+b in fp32
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001

# 1. Default - without TorchDynamo
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
metrics = trainer.evaluate()
original_eval_loss = metrics["eval_loss"]
del trainer

# 2. TorchDynamo eager
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
del trainer

# 3. TorchDynamo nvfuser
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)

@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore this test requires a single GPU
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
x = inputs["x"]
output = model(x)
if self.args.n_gpu == 1:
return output.mean()
return output

class MyModule(torch.nn.Module):
"""Simple module that does aggressive fusion"""

def __init__(self):
super().__init__()

def forward(self, x):
for _ in range(20):

anijain2305 (Contributor, Author), May 25, 2022:

@stas00 Can you try changing this number to 10 to see if the speed improves? If it doesn't, let's make it 1 to see if this is the culprit. At 1 the test will fail, but we will have a little more info to debug.

stas00 (Contributor), May 25, 2022:

  • 1 or 2 fails as expected
  • 5 succeeds
  • 10 hangs - same top of the stack

stas00 (Contributor), May 25, 2022:

I think I figured out the trigger: the Trainer tried to run 2-GPU DP, which somehow led to hanging.

Forcing the test to run on a single GPU gets past the hanging issue:

CUDA_VISIBLE_DEVICES=0 pyt tests/trainer/test_trainer.py -k torchdynamo_memory -sv

now fails:

self = <tests.trainer.test_trainer.TrainerIntegrationTest testMethod=test_torchdynamo_memory>

    @require_torch_gpu
    @require_torchdynamo
    def test_torchdynamo_memory(self):
        class MyModule(torch.nn.Module):
            """Simple module that does aggressive fusion"""
    
            def __init__(self):
                super().__init__()
    
            def forward(self, x):
                for _ in range(20):
                    x = torch.nn.functional.relu(x)
                return x
    
        mod = MyModule()
    
        # 1. Default - without TorchDynamo
        a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
        a.grad = None
        trainer = Trainer(model=mod)
        # warmup
        for _ in range(10):
>           orig_loss = trainer.training_step(mod, {"x": a})

tests/trainer/test_trainer.py:1649: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/trainer.py:2263: in training_step
    loss.backward()
/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/_tensor.py:399: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/autograd/__init__.py:166: in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

outputs = (tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0',
       grad_fn=<SelectBackward0>),), grads = (None,)
is_grads_batched = False

    def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
                    is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
        new_grads: List[_OptionalTensor] = []
        for out, grad in zip(outputs, grads):
            if isinstance(grad, torch.Tensor):
                grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
                if not out.shape == grad_shape:
                    if is_grads_batched:
                        raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
                                           "dimension of each grad_output as the batch dimension. "
                                           "The sizes of the remaining dimensions are expected to match "
                                           "the shape of corresponding output, but a mismatch "
                                           "was detected: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad.shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out.shape) + ". "
                                           "If you only want some tensors in `grad_output` to be considered "
                                           "batched, consider using vmap.")
                    else:
                        raise RuntimeError("Mismatch in shape: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad.shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out.shape) + ".")
                if out.dtype.is_complex != grad.dtype.is_complex:
                    raise RuntimeError("For complex Tensors, both grad_output and output"
                                       " are required to have the same dtype."
                                       " Mismatch in dtype: grad_output["
                                       + str(grads.index(grad)) + "] has a dtype of "
                                       + str(grad.dtype) + " and output["
                                       + str(outputs.index(out)) + "] has a dtype of "
                                       + str(out.dtype) + ".")
                new_grads.append(grad)
            elif grad is None:
                if out.requires_grad:
                    if out.numel() != 1:
>                       raise RuntimeError("grad can be implicitly created only for scalar outputs")
E                       RuntimeError: grad can be implicitly created only for scalar outputs

/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/autograd/__init__.py:67: RuntimeError

anijain2305 (Contributor, Author), May 25, 2022:

I see. For a single GPU, the training_step was not calling mean() on the output. PyTorch expects a scalar loss for the .backward() call, and therefore you saw the error message you pasted. I just added a CustomTrainer to reduce the output to a scalar, and the test passes for a single GPU as well.

I am not sure why it hangs for DP though (maybe it's compiling for each GPU node, and thus the compilation time is shooting up?).

Is it possible to limit the TorchDynamo usage to single GPUs only? We have not really tested TorchDynamo for distributed training.
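
For context, a tiny standalone illustration of the constraint described above (not part of the PR): calling .backward() on a non-scalar tensor without an explicit gradient argument raises the RuntimeError shown in the traceback, while reducing to a scalar with .mean() first, as the added CustomTrainer.compute_loss does, works:

import torch

x = torch.ones(1024, 1024, requires_grad=True)
out = torch.nn.functional.relu(x)  # shape (1024, 1024), not a scalar

try:
    out.backward()  # RuntimeError: grad can be implicitly created only for scalar outputs
except RuntimeError as err:
    print(err)

loss = out.mean()  # reduce to a scalar first
loss.backward()    # works; x.grad is now populated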

stas00 (Contributor):

Perhaps it's because I have a lopsided setup?

GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA GeForce GTX 1070 Ti

Is the compiled version hardware agnostic?

stas00 (Contributor):

It works on 2 GPUs with DP as well. Should we remove the restrictions then and put it back the way you coded it originally?

anijain2305 (Contributor, Author):

Oh, awesome. Thanks for putting time into this.

There is no need to change the test back to the original one. My earlier commit extended the test to single-GPU as well, so the test works for both single and multi-GPU.

stas00 (Contributor):

Except that, due to the @require_torch_non_multi_gpu decorator, it will now only ever be run on a single GPU.

Another reviewer:

The hang seems like concerning behavior. Glad it seems to have sorted itself out, but please let us know if it returns in any form. I do have a single mixed-GPU system I could try reproducing on if it comes back, though I would always recommend running DDP across matching GPUs.

stas00 (Contributor), May 25, 2022:

The hanging was on a single GPU as well, with PyTorch nightly from 05-18.

x = torch.nn.functional.relu(x)
return x

mod = MyModule()

# 1. Default - without TorchDynamo
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
trainer = CustomTrainer(model=mod)
# warmup
for _ in range(10):
orig_loss = trainer.training_step(mod, {"x": a})

torch.cuda.reset_peak_memory_stats()
orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
del trainer

# Reset the peak for another measurement
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# 2. TorchDynamo nvfuser
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
trainer = CustomTrainer(model=mod, args=args)
# warmup
for _ in range(10):
loss = trainer.training_step(mod, {"x": a})

torch.cuda.reset_peak_memory_stats()
loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
del trainer

# Functional check
self.assertAlmostEqual(loss, orig_loss)

# AOT Autograd recomputation and the nvfuser recomputation optimization
# aggressively fuse the operations and reduce the memory footprint.
self.assertGreater(orig_peak_mem, peak_mem * 2)

@require_torch_gpu
@require_torch_bf16
def test_bf16_full_eval(self):