From 4a0fe7f6d80e1d644e2c8f03860d0f3b4b8e1c2b Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Fri, 17 Jun 2022 17:12:52 -0700 Subject: [PATCH 01/24] enable fx2trt --- src/transformers/trainer.py | 5 +++++ src/transformers/training_args.py | 2 +- tests/trainer/test_trainer.py | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 70d11d2fb346..838c8780763c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2251,12 +2251,17 @@ def torchdynamo_smart_context_manager(self): ctx_manager = contextlib.nullcontext() if is_torchdynamo_available(): import torchdynamo + from torchdynamo.optimizations import backends from torchdynamo.optimizations.training import aot_autograd_speedup_strategy if self.args.torchdynamo == "eager": ctx_manager = torchdynamo.optimize("eager") elif self.args.torchdynamo == "nvfuser": ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy) + elif self.args.torchdynamo == "fx2trt-fp16": + ctx_manager = torchdynamo.optimize(backends.fx2trt_compiler_fp16) + elif self.args.torchdynamo == "fx2trt": + ctx_manager = torchdynamo.optimize(backends.fx2trt_compiler) return ctx_manager def autocast_smart_context_manager(self): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5ef4d8d08edd..47c99a5c7dae 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -928,7 +928,7 @@ class TrainingArguments: " are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging." " nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models." ), - "choices": ["eager", "nvfuser"], + "choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"], }, ) ray_scope: Optional[str] = field( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4bae763cc4e8..a4041e9a7c3b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1748,6 +1748,21 @@ def test_torchdynamo_full_eval(self): metrics = trainer.evaluate() self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss) + # 4. TorchDynamo fx2trt + trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt") + metrics = trainer.evaluate() + t1 = metrics["eval_loss"] + t2 = original_eval_loss + self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss) + + # 5. TorchDynamo fx2trt-fp16 + trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16") + metrics = trainer.evaluate() + t1 = metrics["eval_loss"] + t2 = original_eval_loss + # fp16 has accuracy accuracy degradation + self.assertLess(np.max(np.abs(t1 - t2)), 1e-3) + @require_torch_non_multi_gpu @require_torchdynamo def test_torchdynamo_memory(self): From 3cf0eab3b545df0052b14a4951519e44e38ceebd Mon Sep 17 00:00:00 2001 From: Wei Date: Tue, 28 Jun 2022 19:25:43 -0700 Subject: [PATCH 02/24] Update perf_train_gpu_one.mdx --- docs/source/en/perf_train_gpu_one.mdx | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/en/perf_train_gpu_one.mdx b/docs/source/en/perf_train_gpu_one.mdx index 5b120391e29a..7b78b1805e99 100644 --- a/docs/source/en/perf_train_gpu_one.mdx +++ b/docs/source/en/perf_train_gpu_one.mdx @@ -718,3 +718,11 @@ For some applications, such as pretraining large language models, applying all t Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many). +## Inference with torchdynamo +TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose from one of it. +``` +TrainingArguments(torchdynamo="eager") #enable eager model GPU. No performance boost +TrainingArguments(torchdynamo="nvfuser") #enable nvfuser +TrainingArguments(torchdynamo="fx2trt") #enable tensorRT fp32 +TrainingArguments(torchdynamo="fx2trt-f16") #enable tensorRT fp16 +``` From fb92f4cc2457e1c220a75e4ff013074c676545d0 Mon Sep 17 00:00:00 2001 From: Wei Date: Tue, 28 Jun 2022 19:27:01 -0700 Subject: [PATCH 03/24] Update perf_train_gpu_one.mdx --- docs/source/en/perf_train_gpu_one.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_train_gpu_one.mdx b/docs/source/en/perf_train_gpu_one.mdx index 7b78b1805e99..ab01d0019a6b 100644 --- a/docs/source/en/perf_train_gpu_one.mdx +++ b/docs/source/en/perf_train_gpu_one.mdx @@ -719,7 +719,7 @@ For some applications, such as pretraining large language models, applying all t Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many). ## Inference with torchdynamo -TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose from one of it. +TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost. ``` TrainingArguments(torchdynamo="eager") #enable eager model GPU. No performance boost TrainingArguments(torchdynamo="nvfuser") #enable nvfuser From e78f9d5759ee27975c2465b149d0bc40d0c4d21d Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 28 Jun 2022 20:02:41 -0700 Subject: [PATCH 04/24] add lib check --- src/transformers/testing_utils.py | 4 ++++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 2 ++ tests/trainer/test_trainer.py | 3 +++ 4 files changed, 10 insertions(+) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1aebe8f4e2de..746d951cfe67 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -72,6 +72,7 @@ is_torch_tpu_available, is_torchaudio_available, is_torchdynamo_available, + is_torch_tensorrt_fx_available, is_vision_available, ) @@ -480,6 +481,9 @@ def require_torchdynamo(test_case): """Decorator marking a test that requires TorchDynamo""" return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) +def require_torch_tensorrt_fx(test_case): + """Decorator marking a test that requires Torch-TensorRT FX""" + return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index fea13ff47cc8..1ffe2b60721c 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -133,6 +133,7 @@ is_torch_tpu_available, is_torchaudio_available, is_torchdynamo_available, + is_torch_tensorrt_fx_available, is_training_run_on_sagemaker, is_vision_available, requires_backends, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 53f7515bca5f..a744fe8beb9d 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -390,6 +390,8 @@ def is_torch_tpu_available(): def is_torchdynamo_available(): return importlib.util.find_spec("torchdynamo") is not None +def is_torch_tensorrt_fx_available(): + return importlib.util.find_spec("torch_tensorrt.fx") is not None def is_datasets_available(): return _datasets_available diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a4041e9a7c3b..b7da60fb03f8 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -64,6 +64,7 @@ require_torch_tf32, require_torch_up_to_2_gpus, require_torchdynamo, + require_torch_tensorrt_fx, require_wandb, slow, ) @@ -1720,6 +1721,7 @@ def test_fp16_full_eval(self): @require_torch_non_multi_gpu @require_torchdynamo + @require_torch_tensorrt_fx def test_torchdynamo_full_eval(self): # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu n_gpus = get_gpu_count() @@ -1765,6 +1767,7 @@ def test_torchdynamo_full_eval(self): @require_torch_non_multi_gpu @require_torchdynamo + @require_torch_tensorrt_fx def test_torchdynamo_memory(self): # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu class CustomTrainer(Trainer): From 19647846adc06aed11f286df7011853659b5300c Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 29 Jun 2022 16:43:50 -0700 Subject: [PATCH 05/24] update --- tests/trainer/test_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b7da60fb03f8..7ec1bbd42b74 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1765,6 +1765,7 @@ def test_torchdynamo_full_eval(self): # fp16 has accuracy accuracy degradation self.assertLess(np.max(np.abs(t1 - t2)), 1e-3) + @require_torch_non_multi_gpu @require_torchdynamo @require_torch_tensorrt_fx From db121a8d8b059286b37182118ee0dc4d73cbc159 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 30 Jun 2022 14:44:00 -0700 Subject: [PATCH 06/24] format --- src/transformers/testing_utils.py | 2 ++ src/transformers/utils/import_utils.py | 2 ++ tests/trainer/test_trainer.py | 1 - 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 746d951cfe67..2ba4daa334cd 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -481,10 +481,12 @@ def require_torchdynamo(test_case): """Decorator marking a test that requires TorchDynamo""" return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) + def require_torch_tensorrt_fx(test_case): """Decorator marking a test that requires Torch-TensorRT FX""" return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index a744fe8beb9d..7504b4eb186c 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -390,9 +390,11 @@ def is_torch_tpu_available(): def is_torchdynamo_available(): return importlib.util.find_spec("torchdynamo") is not None + def is_torch_tensorrt_fx_available(): return importlib.util.find_spec("torch_tensorrt.fx") is not None + def is_datasets_available(): return _datasets_available diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7ec1bbd42b74..b7da60fb03f8 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1765,7 +1765,6 @@ def test_torchdynamo_full_eval(self): # fp16 has accuracy accuracy degradation self.assertLess(np.max(np.abs(t1 - t2)), 1e-3) - @require_torch_non_multi_gpu @require_torchdynamo @require_torch_tensorrt_fx From 976c7db636f626f94379eb05abf55bcc6c252e15 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 30 Jun 2022 16:42:46 -0700 Subject: [PATCH 07/24] update --- src/transformers/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 188d9f9d420b..7bdd1584b7ad 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -141,6 +141,7 @@ is_sagemaker_mp_enabled, is_torch_tpu_available, is_torchdynamo_available, + is_torch_tensorrt_fx_available, logging, ) from .utils.generic import ContextManagers @@ -2291,7 +2292,7 @@ def torchdynamo_smart_context_manager(self): A helper wrapper that creates an appropriate context manager for `torchdynamo`. """ ctx_manager = contextlib.nullcontext() - if is_torchdynamo_available(): + if is_torchdynamo_available() and is_torch_tensorrt_fx_available(): import torchdynamo from torchdynamo.optimizations import backends from torchdynamo.optimizations.training import aot_autograd_speedup_strategy From fdd10b0e63747d57695fbffd94cde0d765ad7b85 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 30 Jun 2022 17:00:24 -0700 Subject: [PATCH 08/24] fix import check --- src/transformers/utils/import_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 9d5ac8879ffa..454de0caa702 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -419,6 +419,8 @@ def is_torchdynamo_available(): def is_torch_tensorrt_fx_available(): + if importlib.util.find_spec("torch_tensorrt") is None: + return False return importlib.util.find_spec("torch_tensorrt.fx") is not None From a1b9b9ed32ea8088ceaedb6eecaca4e92418f1aa Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 30 Jun 2022 17:43:09 -0700 Subject: [PATCH 09/24] fix isort --- src/transformers/testing_utils.py | 2 +- src/transformers/trainer.py | 2 +- src/transformers/utils/__init__.py | 2 +- tests/trainer/test_trainer.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index dd5eb9178c39..564d6364d259 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -71,11 +71,11 @@ is_torch_available, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, + is_torch_tensorrt_fx_available, is_torch_tf32_available, is_torch_tpu_available, is_torchaudio_available, is_torchdynamo_available, - is_torch_tensorrt_fx_available, is_vision_available, ) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7bdd1584b7ad..a939d8efd136 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -139,9 +139,9 @@ is_ipex_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, + is_torch_tensorrt_fx_available, is_torch_tpu_available, is_torchdynamo_available, - is_torch_tensorrt_fx_available, logging, ) from .utils.generic import ContextManagers diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index ab417b99471c..1ee4521514af 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -132,11 +132,11 @@ is_torch_fx_available, is_torch_fx_proxy, is_torch_onnx_dict_inputs_support_available, + is_torch_tensorrt_fx_available, is_torch_tf32_available, is_torch_tpu_available, is_torchaudio_available, is_torchdynamo_available, - is_torch_tensorrt_fx_available, is_training_run_on_sagemaker, is_vision_available, requires_backends, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1ad261b356b5..3f8557cca061 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -62,10 +62,10 @@ require_torch_gpu, require_torch_multi_gpu, require_torch_non_multi_gpu, + require_torch_tensorrt_fx, require_torch_tf32, require_torch_up_to_2_gpus, require_torchdynamo, - require_torch_tensorrt_fx, require_wandb, slow, ) From 10fd918266937c22b36f1f8a67c5ee41861d64f5 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 5 Jul 2022 23:00:05 -0700 Subject: [PATCH 10/24] improve doc --- docs/source/en/perf_train_gpu_one.mdx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/en/perf_train_gpu_one.mdx b/docs/source/en/perf_train_gpu_one.mdx index ab01d0019a6b..779be029ec30 100644 --- a/docs/source/en/perf_train_gpu_one.mdx +++ b/docs/source/en/perf_train_gpu_one.mdx @@ -11,7 +11,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o # Efficient Training on a Single GPU -This guide focuses on training large models efficiently on a single GPU. These approaches are still valid if you have access to a machine with multiple GPUs but you will also have access to additional methods outlined in the [multi-GPU section](perf_train_gpu_many). +This guide focuses on training large models efficiently on a single GPU. These approaches are still valid if you have access to a machine with multiple GPUs but you will also have access to additional methods outlined in the [multi-GPU section](perf_train_gpu_many). In this section we have a look at a few tricks to reduce the memory footprint and speed up training for large models and how they are integrated in the [`Trainer`] and [🤗 Accelerate](https://huggingface.co/docs/accelerate/). Each method can improve speed or memory usage which is summarized in the table below: @@ -367,7 +367,7 @@ Samples/second: 10.09 GPU memory occupied: 7275 MB. ``` -We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster. +We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster. ### BF16 If you have access to a Ampere or newer hardware you can use bf16 for your training and evaluation. While bf16 has a worse precision than fp16, it has a much much bigger dynamic range. Therefore, if in the past you were experiencing overflow issues while training the model, bf16 will prevent this from happening most of the time. Remember that in fp16 the biggest number you can have is `65535` and any number above that will overflow. A bf16 number can be as large as `3.39e+38` (!) which is about the same as fp32 - because both have 8-bits used for the numerical range. @@ -394,7 +394,7 @@ Like all cases with reduced precision this may or may not be satisfactory for yo If you're already using fp16 or bf16 mixed precision it may help with the throughput as well. -You can enable this mode in the 🤗 Trainer with: +You can enable this mode in the 🤗 Trainer with: ```python TrainingArguments(tf32=True) ``` @@ -654,7 +654,7 @@ https://github.com/huggingface/transformers/blob/master/src/transformers/trainer ## Choice of GPU -Sometimes, even when applying all the above tweaks the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU. For example switching from let's say a K80 (which you typically get on Google Colab) to a fancier GPU such as the V100 or A100. Although they are more expensive they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture. +Sometimes, even when applying all the above tweaks the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU. For example switching from let's say a K80 (which you typically get on Google Colab) to a fancier GPU such as the V100 or A100. Although they are more expensive they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture. Now, let's take a step back and discuss what we should optimize for when scaling the training of large models. @@ -719,7 +719,7 @@ For some applications, such as pretraining large language models, applying all t Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many). ## Inference with torchdynamo -TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost. +TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost. ``` TrainingArguments(torchdynamo="eager") #enable eager model GPU. No performance boost TrainingArguments(torchdynamo="nvfuser") #enable nvfuser From 6c6835d4f5be87957d3cf7701a7b7545f92c324c Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 5 Jul 2022 23:25:24 -0700 Subject: [PATCH 11/24] refactor ctx manager --- src/transformers/trainer.py | 16 +--------------- src/transformers/training_args.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a939d8efd136..6352743db992 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2291,21 +2291,7 @@ def torchdynamo_smart_context_manager(self): """ A helper wrapper that creates an appropriate context manager for `torchdynamo`. """ - ctx_manager = contextlib.nullcontext() - if is_torchdynamo_available() and is_torch_tensorrt_fx_available(): - import torchdynamo - from torchdynamo.optimizations import backends - from torchdynamo.optimizations.training import aot_autograd_speedup_strategy - - if self.args.torchdynamo == "eager": - ctx_manager = torchdynamo.optimize("eager") - elif self.args.torchdynamo == "nvfuser": - ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy) - elif self.args.torchdynamo == "fx2trt-fp16": - ctx_manager = torchdynamo.optimize(backends.fx2trt_compiler_fp16) - elif self.args.torchdynamo == "fx2trt": - ctx_manager = torchdynamo.optimize(backends.fx2trt_compiler) - return ctx_manager + return self.args.ctx_manager_torchdynamo def autocast_smart_context_manager(self): """ diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index ba87c0303f61..1fddab9800c3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1218,6 +1218,29 @@ def __post_init__(self): FutureWarning, ) + if self.torchdynamo: + if not is_torchdynamo_available(): + raise RuntimeError("Torchdynamo is not installed.") + + import torchdynamo + from torchdynamo.optimizations import backends + from torchdynamo.optimizations.training import aot_autograd_speedup_strategy + + if self.torchdynamo == "eager": + self.ctx_manager_torchdynamo = torchdynamo.optimize("eager") + elif self.torchdynamo == "nvfuser": + self.ctx_manager_torchdynamo = torchdynamo.optimize(aot_autograd_speedup_strategy) + elif self.torchdynamo == "fx2trt-fp16": + if not is_torch_tensorrt_fx_available(): + raise RuntimeError("Torch-TensorRT FX path is not installed.") + self.ctx_manager_torchdynamo = torchdynamo.optimize(backends.fx2trt_compiler_fp16) + elif self.torchdynamo == "fx2trt": + if not is_torch_tensorrt_fx_available(): + raise RuntimeError("Torch-TensorRT FX path is not installed.") + self.ctx_manager_torchdynamo = torchdynamo.optimize(backends.fx2trt_compiler) + else: + raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.") + def __str__(self): self_as_dict = asdict(self) From 0d10b6b8f98056151ebc476ab138ff0917a1b1d1 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 5 Jul 2022 23:42:53 -0700 Subject: [PATCH 12/24] fix isort --- src/transformers/trainer.py | 33 +++++++++++++------------------ src/transformers/training_args.py | 3 +++ 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6352743db992..61d721fed450 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -32,9 +32,23 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import torch +import torch.distributed as dist +from packaging import version +from torch import nn +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler from tqdm.auto import tqdm +from huggingface_hub import Repository +from . import __version__ +from .configuration_utils import PretrainedConfig +from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from .debug_utils import DebugOption, DebugUnderflowOverflow +from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled +from .dependency_versions_check import dep_version_check # Integrations must be imported before ML frameworks: from .integrations import ( # isort: split default_hp_search_backend, @@ -50,23 +64,6 @@ run_hp_search_sigopt, run_hp_search_wandb, ) - -import numpy as np -import torch -import torch.distributed as dist -from packaging import version -from torch import nn -from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler -from torch.utils.data.distributed import DistributedSampler - -from huggingface_hub import Repository - -from . import __version__ -from .configuration_utils import PretrainedConfig -from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator -from .debug_utils import DebugOption, DebugUnderflowOverflow -from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled -from .dependency_versions_check import dep_version_check from .modelcard import TrainingSummary from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from .optimization import Adafactor, get_scheduler @@ -139,9 +136,7 @@ is_ipex_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, - is_torch_tensorrt_fx_available, is_torch_tpu_available, - is_torchdynamo_available, logging, ) from .utils.generic import ContextManagers diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 1fddab9800c3..78e79adddca5 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -41,8 +41,10 @@ is_torch_available, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, + is_torch_tensorrt_fx_available, is_torch_tf32_available, is_torch_tpu_available, + is_torchdynamo_available, logging, torch_required, ) @@ -1218,6 +1220,7 @@ def __post_init__(self): FutureWarning, ) + self.ctx_manager_torchdynamo = contextlib.nullcontext() if self.torchdynamo: if not is_torchdynamo_available(): raise RuntimeError("Torchdynamo is not installed.") From f1b6f30a7ef1da73ec2ba323e9e3c9033bf35376 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 6 Jul 2022 00:00:25 -0700 Subject: [PATCH 13/24] black format --- src/transformers/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 61d721fed450..9ad6d182c671 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -49,6 +49,7 @@ from .debug_utils import DebugOption, DebugUnderflowOverflow from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled from .dependency_versions_check import dep_version_check + # Integrations must be imported before ML frameworks: from .integrations import ( # isort: split default_hp_search_backend, From 8312edd94126ccf4703efd0e7652ae8c75aa3325 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 6 Jul 2022 09:40:06 -0700 Subject: [PATCH 14/24] isort fix --- src/transformers/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9ad6d182c671..61d721fed450 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -49,7 +49,6 @@ from .debug_utils import DebugOption, DebugUnderflowOverflow from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled from .dependency_versions_check import dep_version_check - # Integrations must be imported before ML frameworks: from .integrations import ( # isort: split default_hp_search_backend, From 066e069698477d948ddc3e52f6eb2bf3ffc84786 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 6 Jul 2022 12:01:12 -0700 Subject: [PATCH 15/24] fix format --- src/transformers/trainer.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 61d721fed450..263a7846f55f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -32,23 +32,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np -import torch -import torch.distributed as dist -from packaging import version -from torch import nn -from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler -from torch.utils.data.distributed import DistributedSampler from tqdm.auto import tqdm -from huggingface_hub import Repository -from . import __version__ -from .configuration_utils import PretrainedConfig -from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator -from .debug_utils import DebugOption, DebugUnderflowOverflow -from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled -from .dependency_versions_check import dep_version_check # Integrations must be imported before ML frameworks: from .integrations import ( # isort: split default_hp_search_backend, @@ -64,6 +50,23 @@ run_hp_search_sigopt, run_hp_search_wandb, ) + +import numpy as np +import torch +import torch.distributed as dist +from packaging import version +from torch import nn +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + +from huggingface_hub import Repository + +from . import __version__ +from .configuration_utils import PretrainedConfig +from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from .debug_utils import DebugOption, DebugUnderflowOverflow +from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled +from .dependency_versions_check import dep_version_check from .modelcard import TrainingSummary from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from .optimization import Adafactor, get_scheduler From 26496832552c0f66f60f7a4adabec1db19193cd4 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Fri, 8 Jul 2022 23:08:13 -0700 Subject: [PATCH 16/24] update args --- src/transformers/training_args.py | 35 ++++++++++++++++++------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 78e79adddca5..9cd7654f4d87 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1220,7 +1220,6 @@ def __post_init__(self): FutureWarning, ) - self.ctx_manager_torchdynamo = contextlib.nullcontext() if self.torchdynamo: if not is_torchdynamo_available(): raise RuntimeError("Torchdynamo is not installed.") @@ -1229,20 +1228,26 @@ def __post_init__(self): from torchdynamo.optimizations import backends from torchdynamo.optimizations.training import aot_autograd_speedup_strategy - if self.torchdynamo == "eager": - self.ctx_manager_torchdynamo = torchdynamo.optimize("eager") - elif self.torchdynamo == "nvfuser": - self.ctx_manager_torchdynamo = torchdynamo.optimize(aot_autograd_speedup_strategy) - elif self.torchdynamo == "fx2trt-fp16": - if not is_torch_tensorrt_fx_available(): - raise RuntimeError("Torch-TensorRT FX path is not installed.") - self.ctx_manager_torchdynamo = torchdynamo.optimize(backends.fx2trt_compiler_fp16) - elif self.torchdynamo == "fx2trt": - if not is_torch_tensorrt_fx_available(): - raise RuntimeError("Torch-TensorRT FX path is not installed.") - self.ctx_manager_torchdynamo = torchdynamo.optimize(backends.fx2trt_compiler) - else: - raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.") + def get_ctx(): + # Normal + if self.torchdynamo == "eager": + return torchdynamo.optimize("eager") + elif self.torchdynamo == "nvfuser": + return torchdynamo.optimize(aot_autograd_speedup_strategy) + # TensorRT + if self.torchdynamo in ["fx2trt-fp16", "fx2trt"]: + if not is_torch_tensorrt_fx_available(): + raise RuntimeError("Torch-TensorRT FX path is not installed.") + if self.torchdynamo == "fx2trt-fp16": + return torchdynamo.optimize(backends.fx2trt_compiler_fp16) + elif self.torchdynamo == "fx2trt": + return torchdynamo.optimize(backends.fx2trt_compiler) + else: + raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.") + + self.ctx_manager_torchdynamo = get_ctx() + else: + self.ctx_manager_torchdynamo = contextlib.nullcontext() def __str__(self): self_as_dict = asdict(self) From ec115e2d30cb7c4683e876c17eb08f6926110236 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Fri, 8 Jul 2022 23:10:29 -0700 Subject: [PATCH 17/24] update black --- src/transformers/training_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 9cd7654f4d87..60d9e08b22cc 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1239,9 +1239,9 @@ def get_ctx(): if not is_torch_tensorrt_fx_available(): raise RuntimeError("Torch-TensorRT FX path is not installed.") if self.torchdynamo == "fx2trt-fp16": - return torchdynamo.optimize(backends.fx2trt_compiler_fp16) + return torchdynamo.optimize(backends.fx2trt_compiler_fp16) elif self.torchdynamo == "fx2trt": - return torchdynamo.optimize(backends.fx2trt_compiler) + return torchdynamo.optimize(backends.fx2trt_compiler) else: raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.") From d0262aa6e81033da75d8f33dac593d61c1822df0 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 11 Jul 2022 19:51:35 -0700 Subject: [PATCH 18/24] cleanups --- tests/trainer/test_trainer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3f8557cca061..d821fd9e5bfe 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1846,7 +1846,6 @@ def test_torchdynamo_full_eval(self): @require_torch_non_multi_gpu @require_torchdynamo - @require_torch_tensorrt_fx def test_torchdynamo_memory(self): # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu class CustomTrainer(Trainer): @@ -1870,7 +1869,7 @@ def forward(self, x): mod = MyModule() - # 1. Default - without TorchDynamo + # 1. without TorchDynamo (eager baseline) a = torch.ones(1024, 1024, device="cuda", requires_grad=True) a.grad = None trainer = CustomTrainer(model=mod) @@ -1878,16 +1877,15 @@ def forward(self, x): for _ in range(10): orig_loss = trainer.training_step(mod, {"x": a}) + # resets + gc.collect() + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() + orig_loss = trainer.training_step(mod, {"x": a}) orig_peak_mem = torch.cuda.max_memory_allocated() del trainer - # Reset the peak for another measurement - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - # 2. TorchDynamo nvfuser a = torch.ones(1024, 1024, device="cuda", requires_grad=True) a.grad = None @@ -1897,7 +1895,11 @@ def forward(self, x): for _ in range(10): loss = trainer.training_step(mod, {"x": a}) + # resets + gc.collect() + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() + loss = trainer.training_step(mod, {"x": a}) peak_mem = torch.cuda.max_memory_allocated() del trainer From e8bf0a55c4cd673c23394a12d3d4404b8dc6fc50 Mon Sep 17 00:00:00 2001 From: Wei Date: Tue, 12 Jul 2022 00:00:32 -0700 Subject: [PATCH 19/24] Update perf_train_gpu_one.mdx --- docs/source/en/perf_train_gpu_one.mdx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/perf_train_gpu_one.mdx b/docs/source/en/perf_train_gpu_one.mdx index b0d5d74ecaae..0c130b417223 100644 --- a/docs/source/en/perf_train_gpu_one.mdx +++ b/docs/source/en/perf_train_gpu_one.mdx @@ -726,3 +726,7 @@ TrainingArguments(torchdynamo="nvfuser") #enable nvfuser TrainingArguments(torchdynamo="fx2trt") #enable tensorRT fp32 TrainingArguments(torchdynamo="fx2trt-f16") #enable tensorRT fp16 ``` +This feature involves 3 different libraries. To install them, please follow the instructions below: +- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup) +- [Functorch installation](https://github.com/pytorch/functorch#install) +- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation) From 6b1d03b949d6132b882f6c242e12159200cbddd1 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 12 Jul 2022 09:30:14 -0700 Subject: [PATCH 20/24] code refactor --- src/transformers/integrations.py | 25 +++++++++++++++++++++++++ src/transformers/training_args.py | 25 ++----------------------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index bf2968ed96e4..f98435e478f3 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -1015,3 +1015,28 @@ def get_reporting_integration_callbacks(report_to): f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported." ) return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to] + + +def get_torchdynamo_ctx(torchdynamo_str): + import torchdynamo + from torchdynamo.optimizations import backends + from torchdynamo.optimizations.training import aot_autograd_speedup_strategy + + def get_ctx(): + # Normal + if torchdynamo_str == "eager": + return torchdynamo.optimize("eager") + elif torchdynamo_str == "nvfuser": + return torchdynamo.optimize(aot_autograd_speedup_strategy) + # TensorRT + if torchdynamo_str in ["fx2trt-fp16", "fx2trt"]: + if not is_torch_tensorrt_fx_available(): + raise RuntimeError("Torch-TensorRT FX path is not installed.") + if torchdynamo_str == "fx2trt-fp16": + return torchdynamo.optimize(backends.fx2trt_compiler_fp16) + elif torchdynamo_str == "fx2trt": + return torchdynamo.optimize(backends.fx2trt_compiler) + else: + raise RuntimeError(f"Torchdynamo backend {torchdynamo_str} is not supported.") + + return get_ctx() diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c9608d3852a9..11ca41a8e871 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional, Union from .debug_utils import DebugOption +from .integrations import get_torchdynamo_ctx from .trainer_utils import ( EvaluationStrategy, FSDPOption, @@ -1223,29 +1224,7 @@ def __post_init__(self): if self.torchdynamo: if not is_torchdynamo_available(): raise RuntimeError("Torchdynamo is not installed.") - - import torchdynamo - from torchdynamo.optimizations import backends - from torchdynamo.optimizations.training import aot_autograd_speedup_strategy - - def get_ctx(): - # Normal - if self.torchdynamo == "eager": - return torchdynamo.optimize("eager") - elif self.torchdynamo == "nvfuser": - return torchdynamo.optimize(aot_autograd_speedup_strategy) - # TensorRT - if self.torchdynamo in ["fx2trt-fp16", "fx2trt"]: - if not is_torch_tensorrt_fx_available(): - raise RuntimeError("Torch-TensorRT FX path is not installed.") - if self.torchdynamo == "fx2trt-fp16": - return torchdynamo.optimize(backends.fx2trt_compiler_fp16) - elif self.torchdynamo == "fx2trt": - return torchdynamo.optimize(backends.fx2trt_compiler) - else: - raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.") - - self.ctx_manager_torchdynamo = get_ctx() + self.ctx_manager_torchdynamo = get_torchdynamo_ctx(self.torchdynamo) else: self.ctx_manager_torchdynamo = contextlib.nullcontext() From 599e7aa3d025b317e4f5ac8adab3dc5acd28acb5 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 12 Jul 2022 23:27:20 -0700 Subject: [PATCH 21/24] code refactor to init --- src/transformers/integrations.py | 25 ----------------------- src/transformers/trainer.py | 33 ++++++++++++++++++++++++++++++- src/transformers/training_args.py | 9 --------- 3 files changed, 32 insertions(+), 35 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index f98435e478f3..bf2968ed96e4 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -1015,28 +1015,3 @@ def get_reporting_integration_callbacks(report_to): f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported." ) return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to] - - -def get_torchdynamo_ctx(torchdynamo_str): - import torchdynamo - from torchdynamo.optimizations import backends - from torchdynamo.optimizations.training import aot_autograd_speedup_strategy - - def get_ctx(): - # Normal - if torchdynamo_str == "eager": - return torchdynamo.optimize("eager") - elif torchdynamo_str == "nvfuser": - return torchdynamo.optimize(aot_autograd_speedup_strategy) - # TensorRT - if torchdynamo_str in ["fx2trt-fp16", "fx2trt"]: - if not is_torch_tensorrt_fx_available(): - raise RuntimeError("Torch-TensorRT FX path is not installed.") - if torchdynamo_str == "fx2trt-fp16": - return torchdynamo.optimize(backends.fx2trt_compiler_fp16) - elif torchdynamo_str == "fx2trt": - return torchdynamo.optimize(backends.fx2trt_compiler) - else: - raise RuntimeError(f"Torchdynamo backend {torchdynamo_str} is not supported.") - - return get_ctx() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 94f3f10fc0d3..336a10e00415 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -142,6 +142,8 @@ is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available, + is_torchdynamo_available, + is_torch_tensorrt_fx_available, logging, ) from .utils.generic import ContextManagers @@ -597,6 +599,35 @@ def __init__( # very last self._memory_tracker.stop_and_update_metrics() + # torchdynamo + if self.args.torchdynamo: + if not is_torchdynamo_available(): + raise RuntimeError("Torchdynamo is not installed.") + import torchdynamo + from torchdynamo.optimizations import backends + from torchdynamo.optimizations.training import aot_autograd_speedup_strategy + + def get_ctx(): + # Normal + if self.args.torchdynamo == "eager": + return torchdynamo.optimize("eager") + elif self.args.torchdynamo == "nvfuser": + return torchdynamo.optimize(aot_autograd_speedup_strategy) + # TensorRT + if self.args.torchdynamo in ["fx2trt-fp16", "fx2trt"]: + if not is_torch_tensorrt_fx_available(): + raise RuntimeError("Torch-TensorRT FX path is not installed.") + if self.args.torchdynamo == "fx2trt-fp16": + return torchdynamo.optimize(backends.fx2trt_compiler_fp16) + elif self.args.torchdynamo == "fx2trt": + return torchdynamo.optimize(backends.fx2trt_compiler) + else: + raise RuntimeError(f"Torchdynamo backend {self.args.torchdynamo} is not supported.") + + self.ctx_manager_torchdynamo = get_ctx() + else: + self.ctx_manager_torchdynamo = contextlib.nullcontext() + def add_callback(self, callback): """ Add a callback to the current list of [`~transformer.TrainerCallback`]. @@ -2290,7 +2321,7 @@ def torchdynamo_smart_context_manager(self): """ A helper wrapper that creates an appropriate context manager for `torchdynamo`. """ - return self.args.ctx_manager_torchdynamo + return self.ctx_manager_torchdynamo def autocast_smart_context_manager(self): """ diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 11ca41a8e871..012ae3223fbf 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -42,10 +42,8 @@ is_torch_available, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, - is_torch_tensorrt_fx_available, is_torch_tf32_available, is_torch_tpu_available, - is_torchdynamo_available, logging, torch_required, ) @@ -1221,13 +1219,6 @@ def __post_init__(self): FutureWarning, ) - if self.torchdynamo: - if not is_torchdynamo_available(): - raise RuntimeError("Torchdynamo is not installed.") - self.ctx_manager_torchdynamo = get_torchdynamo_ctx(self.torchdynamo) - else: - self.ctx_manager_torchdynamo = contextlib.nullcontext() - def __str__(self): self_as_dict = asdict(self) From 81e13a50b8aaf4eaf63ad565cbac2458a2ad960a Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 12 Jul 2022 23:30:40 -0700 Subject: [PATCH 22/24] remove redundancy --- src/transformers/training_args.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 012ae3223fbf..833dc174c375 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -23,7 +23,6 @@ from typing import Any, Dict, List, Optional, Union from .debug_utils import DebugOption -from .integrations import get_torchdynamo_ctx from .trainer_utils import ( EvaluationStrategy, FSDPOption, From 227727845ed781f7b49388fc74110e445c119f69 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 12 Jul 2022 23:45:26 -0700 Subject: [PATCH 23/24] isort --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 336a10e00415..d5dd61ac768b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -141,9 +141,9 @@ is_ipex_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, + is_torch_tensorrt_fx_available, is_torch_tpu_available, is_torchdynamo_available, - is_torch_tensorrt_fx_available, logging, ) from .utils.generic import ContextManagers From 8b65516d14e5a0371e5821a99633666bd752ec8a Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 13 Jul 2022 08:55:40 -0700 Subject: [PATCH 24/24] replace self.args with args --- src/transformers/trainer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d5dd61ac768b..dcadc02718cf 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -600,7 +600,7 @@ def __init__( self._memory_tracker.stop_and_update_metrics() # torchdynamo - if self.args.torchdynamo: + if args.torchdynamo: if not is_torchdynamo_available(): raise RuntimeError("Torchdynamo is not installed.") import torchdynamo @@ -609,20 +609,20 @@ def __init__( def get_ctx(): # Normal - if self.args.torchdynamo == "eager": + if args.torchdynamo == "eager": return torchdynamo.optimize("eager") - elif self.args.torchdynamo == "nvfuser": + elif args.torchdynamo == "nvfuser": return torchdynamo.optimize(aot_autograd_speedup_strategy) # TensorRT - if self.args.torchdynamo in ["fx2trt-fp16", "fx2trt"]: + if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]: if not is_torch_tensorrt_fx_available(): raise RuntimeError("Torch-TensorRT FX path is not installed.") - if self.args.torchdynamo == "fx2trt-fp16": + if args.torchdynamo == "fx2trt-fp16": return torchdynamo.optimize(backends.fx2trt_compiler_fp16) - elif self.args.torchdynamo == "fx2trt": + elif args.torchdynamo == "fx2trt": return torchdynamo.optimize(backends.fx2trt_compiler) else: - raise RuntimeError(f"Torchdynamo backend {self.args.torchdynamo} is not supported.") + raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.") self.ctx_manager_torchdynamo = get_ctx() else: